#if defined(__MACH__)
#include <mach/clock.h>
#include <mach/mach.h>
#endif
#if !defined(__WIN32__)
#include <sys/stat.h>
#include <sys/types.h>
#if !defined(__ANDROID__) && (!defined(MSHADOW_USE_SSE) || MSHADOW_USE_SSE == 1)
#include <emmintrin.h>
#endif
#endif
#include <algorithm>
#include <array>
#include <assert.h>
#include <atomic>
#include <cblas.h>
#include <cctype>
#include <cfloat>
#include <chrono>
#include <climits>
#include <cmath>
#include <condition_variable>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <deque>
#include <dirent.h>
#include <emmintrin.h>
#include <errno.h>
#include <execinfo.h>
#include <fstream>
#include <functional>
#include <inttypes.h>
#include <iostream>
#include <istream>
#include <limits>
#include <list>
#include <map>
#include <memory>
#include <mutex>
#include <new>
#include <ostream>
#include <queue>
#include <random>
#include <regex>
#include <sched.h>
#include <set>
#include <sstream>
#include <stdbool.h>
#include <stddef.h>
#include <stdexcept>
#include <stdint.h>
#include <stdlib.h>
#include <streambuf>
#include <string>
#include <thread>
#include <time.h>
#include <tuple>
#include <type_traits>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
//===== EXPANDING: dmlc-minimum0.cc =====
/*!
* Copyright 2015 by Contributors.
* \brief Minimum DMLC library amalgamation, used for easy plugin of the dmlc lib.
* Normally this is not needed.
*/
//===== EXPANDING: ../dmlc-core/src/io/line_split.cc =====
// Copyright by Contributors
//===== EXPANDING: ../dmlc-core/include/dmlc/io.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file io.h
* \brief defines serializable interface of dmlc
*/
#ifndef DMLC_IO_H_
#define DMLC_IO_H_
//===== EXPANDING: ../dmlc-core/include/dmlc/logging.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file logging.h
* \brief defines logging macros of dmlc
* allows use of GLOG, falling back to the internal
* implementation when disabled
*/
#ifndef DMLC_LOGGING_H_
#define DMLC_LOGGING_H_
//===== EXPANDING: ../dmlc-core/include/dmlc/base.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file base.h
* \brief defines configuration macros
*/
#ifndef DMLC_BASE_H_
#define DMLC_BASE_H_
/*! \brief whether use glog for logging */
#ifndef DMLC_USE_GLOG
#define DMLC_USE_GLOG 0
#endif
/*!
* \brief whether to throw dmlc::Error instead of
* directly calling abort when a FATAL error occurs
* NOTE: this may still not be perfect.
* do not use FATAL and CHECK in destructors
*/
#ifndef DMLC_LOG_FATAL_THROW
#define DMLC_LOG_FATAL_THROW 1
#endif
/*!
* \brief whether to always log a message before throwing
* This can help identify errors that cannot be caught.
*/
#ifndef DMLC_LOG_BEFORE_THROW
#define DMLC_LOG_BEFORE_THROW 1
#endif
/*!
* \brief Whether to use customized logger,
* whose output can be decided by other libraries.
*/
#ifndef DMLC_LOG_CUSTOMIZE
#define DMLC_LOG_CUSTOMIZE 0
#endif
/*!
* \brief Whether to print a stack trace for fatal errors,
* enabled on Linux when using gcc.
*/
#if (!defined(DMLC_LOG_STACK_TRACE) && defined(__GNUC__))
#define DMLC_LOG_STACK_TRACE 1
#endif
/*! \brief whether to compile with hdfs support */
#ifndef DMLC_USE_HDFS
#define DMLC_USE_HDFS 0
#endif
/*! \brief whether to compile with s3 support */
#ifndef DMLC_USE_S3
#define DMLC_USE_S3 0
#endif
/*! \brief whether or not to use the parameter server */
#ifndef DMLC_USE_PS
#define DMLC_USE_PS 0
#endif
/*! \brief whether or not to use c++11 support */
#ifndef DMLC_USE_CXX11
#define DMLC_USE_CXX11 (defined(__GXX_EXPERIMENTAL_CXX0X__) ||\
__cplusplus >= 201103L || defined(_MSC_VER))
#endif
/*! \brief strict CXX11 support */
#ifndef DMLC_STRICT_CXX11
#define DMLC_STRICT_CXX11 (__cplusplus >= 201103L || defined(_MSC_VER))
#endif
/*! \brief whether RTTI is enabled */
#ifndef DMLC_ENABLE_RTTI
#define DMLC_ENABLE_RTTI 1
#endif
/// check if g++ is before 4.6
#if DMLC_USE_CXX11 && defined(__GNUC__) && !defined(__clang_version__)
#if __GNUC__ == 4 && __GNUC_MINOR__ < 6
#pragma message("Will need g++-4.6 or higher to compile all " \
"the features in dmlc-core, " \
"compile without c++0x, some features may be disabled")
#undef DMLC_USE_CXX11
#define DMLC_USE_CXX11 0
#endif
#endif
/*!
* \brief Enable std::thread related modules,
* used to disable some modules in the MinGW build.
*/
#ifndef DMLC_ENABLE_STD_THREAD
#define DMLC_ENABLE_STD_THREAD DMLC_USE_CXX11
#endif
/*! \brief whether to enable regex support; actually needs g++-4.9 or higher */
#ifndef DMLC_USE_REGEX
#define DMLC_USE_REGEX (__cplusplus >= 201103L || defined(_MSC_VER))
#endif
/*! \brief helper macro to suppress unused warnings */
#if defined(__GNUC__)
#define DMLC_ATTRIBUTE_UNUSED __attribute__((unused))
#else
#define DMLC_ATTRIBUTE_UNUSED
#endif
/*! \brief helper macro to generate string concat */
#define DMLC_STR_CONCAT_(__x, __y) __x##__y
#define DMLC_STR_CONCAT(__x, __y) DMLC_STR_CONCAT_(__x, __y)
/*!
* \brief Disable copy constructor and assignment operator.
*
* If C++11 is supported, both copy and move constructors and
* assignment operators are deleted explicitly. Otherwise, they are
* only declared but not implemented. Place this macro in private
* section if C++11 is not available.
*/
#ifndef DISALLOW_COPY_AND_ASSIGN
# if DMLC_USE_CXX11
# define DISALLOW_COPY_AND_ASSIGN(T) \
T(T const&) = delete; \
T(T&&) = delete; \
T& operator=(T const&) = delete; \
T& operator=(T&&) = delete
# else
# define DISALLOW_COPY_AND_ASSIGN(T) \
T(T const&); \
T& operator=(T const&)
# endif
#endif
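/*!
 * Usage sketch for DISALLOW_COPY_AND_ASSIGN (the class name Widget below is
 * only an illustration, not part of dmlc-core):
 * \code
 *
 * class Widget {
 *  public:
 *   Widget() {}
 *  private:
 *   // removes copy (and, under C++11, move) construction and assignment
 *   DISALLOW_COPY_AND_ASSIGN(Widget);
 * };
 * \endcode
 */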
///
/// code block to handle optionally loading
///
#if !defined(__GNUC__)
#define fopen64 std::fopen
#endif
#if (defined __MINGW32__) && !(defined __MINGW64__)
#define fopen64 std::fopen
#endif
#ifdef _MSC_VER
#if _MSC_VER < 1900
// NOTE: sprintf_s is not fully equivalent to snprintf;
// they behave the same on success, which is sufficient for our case
#define snprintf sprintf_s
#define vsnprintf vsprintf_s
#endif
#else
#ifdef _FILE_OFFSET_BITS
#if _FILE_OFFSET_BITS == 32
#pragma message("Warning: FILE OFFSET BITS defined to be 32 bit")
#endif
#endif
#ifdef __APPLE__
#define off64_t off_t
#define fopen64 std::fopen
#endif
extern "C" {
}
#endif
#ifdef _MSC_VER
//! \cond Doxygen_Suppress
typedef signed char int8_t;
typedef __int16 int16_t;
typedef __int32 int32_t;
typedef __int64 int64_t;
typedef unsigned char uint8_t;
typedef unsigned __int16 uint16_t;
typedef unsigned __int32 uint32_t;
typedef unsigned __int64 uint64_t;
//! \endcond
#else
#endif
#if defined(_MSC_VER) && _MSC_VER < 1900
#define noexcept_true throw ()
#define noexcept_false
#define noexcept(a) noexcept_##a
#endif
#if DMLC_USE_CXX11
#define DMLC_THROW_EXCEPTION noexcept(false)
#define DMLC_NO_EXCEPTION noexcept(true)
#else
#define DMLC_THROW_EXCEPTION
#define DMLC_NO_EXCEPTION
#endif
/*! \brief namespace for dmlc */
namespace dmlc {
/*!
* \brief safely get the beginning address of a vector
* \param vec input vector
* \return beginning address of a vector
*/
template<typename T>
inline T *BeginPtr(std::vector<T> &vec) { // NOLINT(*)
if (vec.size() == 0) {
return NULL;
} else {
return &vec[0];
}
}
/*!
* \brief get the beginning address of a const vector
* \param vec input vector
* \return beginning address of a vector
*/
template<typename T>
inline const T *BeginPtr(const std::vector<T> &vec) {
if (vec.size() == 0) {
return NULL;
} else {
return &vec[0];
}
}
/*!
* \brief get the beginning address of a string
* \param str input string
* \return beginning address of a string
*/
inline char* BeginPtr(std::string &str) { // NOLINT(*)
if (str.length() == 0) return NULL;
return &str[0];
}
/*!
* \brief get the beginning address of a const string
* \param str input string
* \return beginning address of a string
*/
inline const char* BeginPtr(const std::string &str) {
if (str.length() == 0) return NULL;
return &str[0];
}
} // namespace dmlc
#if defined(_MSC_VER) && _MSC_VER < 1900
#define constexpr const
#define alignof __alignof
#endif
#endif // DMLC_BASE_H_
//===== EXPANDED: ../dmlc-core/include/dmlc/base.h =====
#if DMLC_LOG_STACK_TRACE
#endif
namespace dmlc {
/*!
* \brief exception class that will be thrown by
* default logger if DMLC_LOG_FATAL_THROW == 1
*/
struct Error : public std::runtime_error {
/*!
* \brief constructor
* \param s the error message
*/
explicit Error(const std::string &s) : std::runtime_error(s) {}
};
} // namespace dmlc
#if DMLC_USE_GLOG
namespace dmlc {
/*!
* \brief optionally redirect to google's init log
* \param argv0 the name of the program (argv[0]).
*/
inline void InitLogging(const char* argv0) {
google::InitGoogleLogging(argv0);
}
} // namespace dmlc
#else
// use a light version of glog
#if defined(_MSC_VER)
#pragma warning(disable : 4722)
#endif
namespace dmlc {
inline void InitLogging(const char* argv0) {
// DO NOTHING
}
class LogCheckError {
public:
LogCheckError() : str(nullptr) {}
explicit LogCheckError(const std::string& str_) : str(new std::string(str_)) {}
~LogCheckError() { if (str != nullptr) delete str; }
operator bool() {return str != nullptr; }
std::string* str;
};
#define DEFINE_CHECK_FUNC(name, op) \
template <typename X, typename Y> \
inline LogCheckError LogCheck##name(const X& x, const Y& y) { \
if (x op y) return LogCheckError(); \
std::ostringstream os; \
os << " (" << x << " vs. " << y << ") "; /* CHECK_XX(x, y) requires x and y can be serialized to string. Use CHECK(x OP y) otherwise. NOLINT(*) */ \
return LogCheckError(os.str()); \
} \
inline LogCheckError LogCheck##name(int x, int y) { \
return LogCheck##name<int, int>(x, y); \
}
#define CHECK_BINARY_OP(name, op, x, y) \
if (dmlc::LogCheckError _check_err = dmlc::LogCheck##name(x, y)) \
dmlc::LogMessageFatal(__FILE__, __LINE__).stream() \
<< "Check failed: " << #x " " #op " " #y << *(_check_err.str)
DEFINE_CHECK_FUNC(_LT, <)
DEFINE_CHECK_FUNC(_GT, >)
DEFINE_CHECK_FUNC(_LE, <=)
DEFINE_CHECK_FUNC(_GE, >=)
DEFINE_CHECK_FUNC(_EQ, ==)
DEFINE_CHECK_FUNC(_NE, !=)
// Always-on checking
#define CHECK(x) \
if (!(x)) \
dmlc::LogMessageFatal(__FILE__, __LINE__).stream() \
<< "Check failed: " #x << ' '
#define CHECK_LT(x, y) CHECK_BINARY_OP(_LT, <, x, y)
#define CHECK_GT(x, y) CHECK_BINARY_OP(_GT, >, x, y)
#define CHECK_LE(x, y) CHECK_BINARY_OP(_LE, <=, x, y)
#define CHECK_GE(x, y) CHECK_BINARY_OP(_GE, >=, x, y)
#define CHECK_EQ(x, y) CHECK_BINARY_OP(_EQ, ==, x, y)
#define CHECK_NE(x, y) CHECK_BINARY_OP(_NE, !=, x, y)
#define CHECK_NOTNULL(x) \
((x) == NULL ? dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "Check notnull: " #x << ' ', (x) : (x)) // NOLINT(*)
// Debug-only checking.
#ifdef NDEBUG
#define DCHECK(x) \
while (false) CHECK(x)
#define DCHECK_LT(x, y) \
while (false) CHECK((x) < (y))
#define DCHECK_GT(x, y) \
while (false) CHECK((x) > (y))
#define DCHECK_LE(x, y) \
while (false) CHECK((x) <= (y))
#define DCHECK_GE(x, y) \
while (false) CHECK((x) >= (y))
#define DCHECK_EQ(x, y) \
while (false) CHECK((x) == (y))
#define DCHECK_NE(x, y) \
while (false) CHECK((x) != (y))
#else
#define DCHECK(x) CHECK(x)
#define DCHECK_LT(x, y) CHECK((x) < (y))
#define DCHECK_GT(x, y) CHECK((x) > (y))
#define DCHECK_LE(x, y) CHECK((x) <= (y))
#define DCHECK_GE(x, y) CHECK((x) >= (y))
#define DCHECK_EQ(x, y) CHECK((x) == (y))
#define DCHECK_NE(x, y) CHECK((x) != (y))
#endif // NDEBUG
#if DMLC_LOG_CUSTOMIZE
#define LOG_INFO dmlc::CustomLogMessage(__FILE__, __LINE__)
#else
#define LOG_INFO dmlc::LogMessage(__FILE__, __LINE__)
#endif
#define LOG_ERROR LOG_INFO
#define LOG_WARNING LOG_INFO
#define LOG_FATAL dmlc::LogMessageFatal(__FILE__, __LINE__)
#define LOG_QFATAL LOG_FATAL
// Poor man's version of VLOG
#define VLOG(x) LOG_INFO.stream()
#define LOG(severity) LOG_##severity.stream()
#define LG LOG_INFO.stream()
#define LOG_IF(severity, condition) \
!(condition) ? (void)0 : dmlc::LogMessageVoidify() & LOG(severity)
#ifdef NDEBUG
#define LOG_DFATAL LOG_ERROR
#define DFATAL ERROR
#define DLOG(severity) true ? (void)0 : dmlc::LogMessageVoidify() & LOG(severity)
#define DLOG_IF(severity, condition) \
(true || !(condition)) ? (void)0 : dmlc::LogMessageVoidify() & LOG(severity)
#else
#define LOG_DFATAL LOG_FATAL
#define DFATAL FATAL
#define DLOG(severity) LOG(severity)
#define DLOG_IF(severity, condition) LOG_IF(severity, condition)
#endif
// Poor man's version of LOG_EVERY_N
#define LOG_EVERY_N(severity, n) LOG(severity)
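/*!
 * Usage sketch for the logging and check macros above (the variables and
 * values below are hypothetical, for illustration only):
 * \code
 *
 * int n_rows = 10, n_cols = 4;
 * CHECK_GT(n_rows, 0) << "need at least one row";
 * CHECK_EQ(n_rows % 2, 0) << "n_rows must be even, got " << n_rows;
 * LOG(INFO) << "matrix is " << n_rows << " x " << n_cols;
 * LOG_IF(ERROR, n_cols > 1024) << "suspiciously wide matrix";
 * DCHECK_LE(n_cols, n_rows);  // compiled out when NDEBUG is defined
 * \endcode
 */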
class DateLogger {
public:
DateLogger() {
#if defined(_MSC_VER)
_tzset();
#endif
}
const char* HumanDate() {
#if defined(_MSC_VER)
_strtime_s(buffer_, sizeof(buffer_));
#else
time_t time_value = time(NULL);
struct tm *pnow;
#if !defined(_WIN32)
struct tm now;
pnow = localtime_r(&time_value, &now);
#else
pnow = localtime(&time_value); // NOLINT(*)
#endif
snprintf(buffer_, sizeof(buffer_), "%02d:%02d:%02d",
pnow->tm_hour, pnow->tm_min, pnow->tm_sec);
#endif
return buffer_;
}
private:
char buffer_[9];
};
class LogMessage {
public:
LogMessage(const char* file, int line)
:
#ifdef __ANDROID__
log_stream_(std::cout)
#else
log_stream_(std::cerr)
#endif
{
log_stream_ << "[" << pretty_date_.HumanDate() << "] " << file << ":"
<< line << ": ";
}
~LogMessage() { log_stream_ << '\n'; }
std::ostream& stream() { return log_stream_; }
protected:
std::ostream& log_stream_;
private:
DateLogger pretty_date_;
LogMessage(const LogMessage&);
void operator=(const LogMessage&);
};
// customized logger that allows the user to define where to log the message.
class CustomLogMessage {
public:
CustomLogMessage(const char* file, int line) {
log_stream_ << "[" << DateLogger().HumanDate() << "] " << file << ":"
<< line << ": ";
}
~CustomLogMessage() {
Log(log_stream_.str());
}
std::ostream& stream() { return log_stream_; }
/*!
* \brief customized logging of the message.
* This function won't be implemented by libdmlc
* \param msg The message to be logged.
*/
static void Log(const std::string& msg);
private:
std::ostringstream log_stream_;
};
#if DMLC_LOG_FATAL_THROW == 0
class LogMessageFatal : public LogMessage {
public:
LogMessageFatal(const char* file, int line) : LogMessage(file, line) {}
~LogMessageFatal() {
#if DMLC_LOG_STACK_TRACE
const int MAX_STACK_SIZE = 256;
void *stack[MAX_STACK_SIZE];
int nframes = backtrace(stack, MAX_STACK_SIZE);
log_stream_ << "\n\n" << "Stack trace returned " << nframes << " entries:\n";
char **msgs = backtrace_symbols(stack, nframes);
if (msgs != nullptr) {
for (int i = 0; i < nframes; ++i) {
log_stream_ << "[bt] (" << i << ") " << msgs[i] << "\n";
}
}
#endif
log_stream_ << "\n";
abort();
}
private:
LogMessageFatal(const LogMessageFatal&);
void operator=(const LogMessageFatal&);
};
#else
class LogMessageFatal {
public:
LogMessageFatal(const char* file, int line) {
log_stream_ << "[" << pretty_date_.HumanDate() << "] " << file << ":"
<< line << ": ";
}
std::ostringstream &stream() { return log_stream_; }
~LogMessageFatal() DMLC_THROW_EXCEPTION {
#if DMLC_LOG_STACK_TRACE
const int MAX_STACK_SIZE = 256;
void *stack[MAX_STACK_SIZE];
int nframes = backtrace(stack, MAX_STACK_SIZE);
log_stream_ << "\n\n" << "Stack trace returned " << nframes << " entries:\n";
char **msgs = backtrace_symbols(stack, nframes);
if (msgs != nullptr) {
for (int i = 0; i < nframes; ++i) {
log_stream_ << "[bt] (" << i << ") " << msgs[i] << "\n";
}
}
#endif
// throwing out of destructor is evil
// hopefully we can do it here
// also log the message before throw
#if DMLC_LOG_BEFORE_THROW
LOG(ERROR) << log_stream_.str();
#endif
throw Error(log_stream_.str());
}
private:
std::ostringstream log_stream_;
DateLogger pretty_date_;
LogMessageFatal(const LogMessageFatal&);
void operator=(const LogMessageFatal&);
};
#endif
// This class is used to explicitly ignore values in the conditional
// logging macros. This avoids compiler warnings like "value computed
// is not used" and "statement has no effect".
class LogMessageVoidify {
public:
LogMessageVoidify() {}
// This has to be an operator with a precedence lower than << but
// higher than "?:". See its usage.
void operator&(std::ostream&) {}
};
} // namespace dmlc
#endif
#endif // DMLC_LOGGING_H_
//===== EXPANDED: ../dmlc-core/include/dmlc/logging.h =====
// include uint64_t only to make io standalone
#ifdef _MSC_VER
/*! \brief uint64 */
typedef unsigned __int64 uint64_t;
#else
#endif
/*! \brief namespace for dmlc */
namespace dmlc {
/*!
* \brief interface of stream I/O for serialization
*/
class Stream { // NOLINT(*)
public:
/*!
* \brief reads data from a stream
* \param ptr pointer to a memory buffer
* \param size block size
* \return the size of data read
*/
virtual size_t Read(void *ptr, size_t size) = 0;
/*!
* \brief writes data to a stream
* \param ptr pointer to a memory buffer
* \param size block size
*/
virtual void Write(const void *ptr, size_t size) = 0;
/*! \brief virtual destructor */
virtual ~Stream(void) {}
/*!
* \brief generic factory function
* create a stream; the stream will close the underlying files upon deletion
*
* \param uri the uri of the input; currently we support
* hdfs://, s3://, and file://; by default file:// will be used
* \param flag can be "w", "r", "a"
* \param allow_null whether NULL can be returned, or directly report an error
* \return the created stream, can be NULL when allow_null == true and the file does not exist
*/
static Stream *Create(const char *uri,
const char* const flag,
bool allow_null = false);
// helper functions to write/read different data structures
/*!
* \brief writes data to the stream
*
* dmlc::Stream supports Write/Read of most STL
* composites and base types.
* If the data type is not supported, a compile time error will
* be issued.
*
* \param data data to be written
* \tparam T the data type to be written
*/
template<typename T>
inline void Write(const T &data);
/*!
* \brief loads data from the stream.
*
* dmlc::Stream supports Write/Read of most STL
* composites and base types.
* If the data type is not supported, a compile time error will
* be issued.
*
* \param out_data placeholder for the data to be deserialized
* \return whether the load was successful
*/
template<typename T>
inline bool Read(T *out_data);
};
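/*!
 * Usage sketch for Stream with the templated Write/Read helpers declared
 * above (the local path ./model.bin is only an assumption for illustration):
 * \code
 *
 * std::vector<float> weights(16, 0.5f);
 * std::string name = "layer0";
 * dmlc::Stream *fo = dmlc::Stream::Create("./model.bin", "w");
 * fo->Write(weights);
 * fo->Write(name);
 * delete fo;  // closes the underlying file
 *
 * std::vector<float> weights2;
 * std::string name2;
 * dmlc::Stream *fi = dmlc::Stream::Create("./model.bin", "r");
 * CHECK(fi->Read(&weights2));
 * CHECK(fi->Read(&name2));
 * delete fi;
 * \endcode
 */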
/*! \brief interface of an i/o stream that supports seek */
class SeekStream: public Stream {
public:
// virtual destructor
virtual ~SeekStream(void) {}
/*! \brief seek to certain position of the file */
virtual void Seek(size_t pos) = 0;
/*! \brief tell the position of the stream */
virtual size_t Tell(void) = 0;
/*!
* \brief generic factory function
* create a SeekStream for read only;
* the stream will close the underlying files upon deletion,
* an error is reported and the system exits when creation fails
* \param uri the uri of the input; currently we support
* hdfs://, s3://, and file://; by default file:// will be used
* \param allow_null whether NULL can be returned, or directly report an error
* \return the created stream, can be NULL when allow_null == true and the file does not exist
*/
static SeekStream *CreateForRead(const char *uri,
bool allow_null = false);
};
/*! \brief interface for serializable objects */
class Serializable {
public:
/*! \brief virtual destructor */
virtual ~Serializable() {}
/*!
* \brief load the model from a stream
* \param fi the stream to load the model from
*/
virtual void Load(Stream *fi) = 0;
/*!
* \brief saves the model to a stream
* \param fo the stream to save the model to
*/
virtual void Save(Stream *fo) const = 0;
};
/*!
* \brief input split that allows reading
* of records from a split of the data;
* the splits are independent parts that together cover the whole dataset
*
* see InputSplit::Create for definition of record
*/
class InputSplit {
public:
/*! \brief a blob of memory region */
struct Blob {
/*! \brief points to start of the memory region */
void *dptr;
/*! \brief size of the memory region */
size_t size;
};
/*!
* \brief hint the InputSplit how large a chunk size
* it should return when implementing NextChunk;
* this is a hint and may not be enforced,
* but InputSplit will try to adjust its internal buffer
* size to the hinted value
* \param chunk_size the chunk size
*/
virtual void HintChunkSize(size_t chunk_size) {}
/*! \brief get the total size of the InputSplit */
virtual size_t GetTotalSize(void) = 0;
/*! \brief reset the position of InputSplit to beginning */
virtual void BeforeFirst(void) = 0;
/*!
* \brief get the next record; the returned value
* is valid until the next call to NextRecord or NextChunk;
* the caller can modify the memory content of out_rec
*
* For text, out_rec contains a single line
* For recordio, out_rec contains one record's content (with the header stripped)
*
* \param out_rec used to store the result
* \return true if we can successfully get next record
* false if we reached end of split
* \sa InputSplit::Create for definition of record
*/
virtual bool NextRecord(Blob *out_rec) = 0;
/*!
* \brief get a chunk of memory that can contain multiple records;
* the caller needs to parse the content of the resulting chunk,
* for a text file, out_chunk can contain data of multiple lines
* for recordio, out_chunk can contain multiple records (including headers)
*
* This function ensures there will be no partial record in the chunk;
* the caller can modify the memory content of out_chunk,
* the memory is valid until the next call to NextRecord or NextChunk
*
* Usually NextRecord is sufficient; NextChunk can be used by some
* multi-threaded parsers to parse the input content
*
* \param out_chunk used to store the result
* \return true if we can successfully get next record
* false if we reached end of split
* \sa InputSplit::Create for definition of record
* \sa RecordIOChunkReader to parse recordio content from out_chunk
*/
virtual bool NextChunk(Blob *out_chunk) = 0;
/*! \brief destructor*/
virtual ~InputSplit(void) {}
/*!
* \brief reset the InputSplit to a certain part id;
* the InputSplit will point to the head of the newly specified segment.
* This feature may not be supported by every implementation of InputSplit.
* \param part_index The part id of the new input.
* \param num_parts The total number of parts.
*/
virtual void ResetPartition(unsigned part_index, unsigned num_parts) = 0;
/*!
* \brief factory function:
* create input split given a uri
* \param uri the uri of the input, can contain hdfs prefix
* \param part_index the part id of current input
* \param num_parts total number of splits
* \param type type of record
* List of possible types: "text", "recordio"
* - "text":
* text file, each line is treated as a record
* input split will split on '\\n' or '\\r'
* - "recordio":
* binary recordio file, see recordio.h
* \return a new input split
* \sa InputSplit::Type
*/
static InputSplit* Create(const char *uri,
unsigned part_index,
unsigned num_parts,
const char *type);
};
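/*!
 * Usage sketch for InputSplit::Create reading a text file line by line
 * (the path ./train.txt and the single-partition setup are assumptions
 * for illustration):
 * \code
 *
 * dmlc::InputSplit *split =
 *     dmlc::InputSplit::Create("./train.txt", 0, 1, "text");
 * dmlc::InputSplit::Blob rec;
 * while (split->NextRecord(&rec)) {
 *   // rec.dptr points to one line, rec.size is its length
 *   std::string line(static_cast<char*>(rec.dptr), rec.size);
 * }
 * delete split;
 * \endcode
 */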
/*!
* \brief a std::ostream class that can wrap Stream objects;
* output written to the ostream goes to the underlying Stream
*
* Usage example:
* \code
*
* Stream *fs = Stream::Create("hdfs:///test.txt", "w");
* dmlc::ostream os(fs);
* os << "hello world" << std::endl;
* delete fs;
* \endcode
*/
class ostream : public std::basic_ostream<char> {
public:
/*!
* \brief construct std::ostream type
* \param stream the Stream output to be used
* \param buffer_size internal streambuf size
*/
explicit ostream(Stream *stream,
size_t buffer_size = (1 << 10))
: std::basic_ostream<char>(NULL), buf_(buffer_size) {
this->set_stream(stream);
}
// explicitly synchronize the buffer
virtual ~ostream() DMLC_NO_EXCEPTION {
buf_.pubsync();
}
/*!
* \brief set internal stream to be stream, reset states
* \param stream new stream as output
*/
inline void set_stream(Stream *stream) {
buf_.set_stream(stream);
this->rdbuf(&buf_);
}
/*! \return how many bytes we have written so far */
inline size_t bytes_written(void) const {
return buf_.bytes_out();
}
private:
// internal streambuf
class OutBuf : public std::streambuf {
public:
explicit OutBuf(size_t buffer_size)
: stream_(NULL), buffer_(buffer_size), bytes_out_(0) {
if (buffer_size == 0) buffer_.resize(2);
}
// set stream to the buffer
inline void set_stream(Stream *stream);
inline size_t bytes_out() const { return bytes_out_; }
private:
/*! \brief internal stream by StreamBuf */
Stream *stream_;
/*! \brief internal buffer */
std::vector<char> buffer_;
/*! \brief number of bytes written so far */
size_t bytes_out_;
// override sync
inline int_type sync(void);
// override overflow
inline int_type overflow(int c);
};
/*! \brief buffer of the stream */
OutBuf buf_;
};
/*!
* \brief a std::istream class that can wrap Stream objects;
* input read from the istream comes from the underlying Stream
*
* Usage example:
* \code
*
* Stream *fs = Stream::Create("hdfs:///test.txt", "r");
* dmlc::istream is(fs);
* is >> mydata;
* delete fs;
* \endcode
*/
class istream : public std::basic_istream<char> {
public:
/*!
* \brief construct std::istream type
* \param stream the Stream input to be used
* \param buffer_size internal buffer size
*/
explicit istream(Stream *stream,
size_t buffer_size = (1 << 10))
: std::basic_istream<char>(NULL), buf_(buffer_size) {
this->set_stream(stream);
}
virtual ~istream() DMLC_NO_EXCEPTION {}
/*!
* \brief set internal stream to be stream, reset states
* \param stream new stream as input
*/
inline void set_stream(Stream *stream) {
buf_.set_stream(stream);
this->rdbuf(&buf_);
}
/*! \return how many bytes we have read so far */
inline size_t bytes_read(void) const {
return buf_.bytes_read();
}
private:
// internal streambuf
class InBuf : public std::streambuf {
public:
explicit InBuf(size_t buffer_size)
: stream_(NULL), bytes_read_(0),
buffer_(buffer_size) {
if (buffer_size == 0) buffer_.resize(2);
}
// set stream to the buffer
inline void set_stream(Stream *stream);
// return how many bytes read so far
inline size_t bytes_read(void) const {
return bytes_read_;
}
private:
/*! \brief internal stream by StreamBuf */
Stream *stream_;
/*! \brief how many bytes we read so far */
size_t bytes_read_;
/*! \brief internal buffer */
std::vector<char> buffer_;
// override underflow
inline int_type underflow();
};
/*! \brief input buffer */
InBuf buf_;
};
} // namespace dmlc
//===== EXPANDING: ../dmlc-core/include/dmlc/serializer.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file serializer.h
* \brief serializer template class that helps serialization.
* This file does not need to be used directly by most users.
*/
#ifndef DMLC_SERIALIZER_H_
#define DMLC_SERIALIZER_H_
//===== EXPANDING: ../dmlc-core/include/dmlc/type_traits.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file type_traits.h
* \brief type traits information header
*/
#ifndef DMLC_TYPE_TRAITS_H_
#define DMLC_TYPE_TRAITS_H_
#if DMLC_USE_CXX11
#endif
namespace dmlc {
/*!
* \brief whether a type is a POD type
* \tparam T the type to query
*/
template<typename T>
struct is_pod {
#if DMLC_USE_CXX11
/*! \brief the value of the traits */
static const bool value = std::is_pod<T>::value;
#else
/*! \brief the value of the traits */
static const bool value = false;
#endif
};
/*!
* \brief whether a type is an integral type
* \tparam T the type to query
*/
template<typename T>
struct is_integral {
#if DMLC_USE_CXX11
/*! \brief the value of the traits */
static const bool value = std::is_integral<T>::value;
#else
/*! \brief the value of the traits */
static const bool value = false;
#endif
};
/*!
* \brief whether a type is a floating point type
* \tparam T the type to query
*/
template<typename T>
struct is_floating_point {
#if DMLC_USE_CXX11
/*! \brief the value of the traits */
static const bool value = std::is_floating_point<T>::value;
#else
/*! \brief the value of the traits */
static const bool value = false;
#endif
};
/*!
* \brief whether a type is an arithmetic type
* \tparam T the type to query
*/
template<typename T>
struct is_arithmetic {
#if DMLC_USE_CXX11
/*! \brief the value of the traits */
static const bool value = std::is_arithmetic<T>::value;
#else
/*! \brief the value of the traits */
static const bool value = (dmlc::is_integral<T>::value ||
dmlc::is_floating_point<T>::value);
#endif
};
/*!
* \brief the string representation of a type name
* \tparam T the type to query
* \return a constant string of the type name.
*/
template<typename T>
inline const char* type_name() {
return "";
}
/*!
* \brief whether a type has save/load functions
* \tparam T the type to query
*/
template<typename T>
struct has_saveload {
/*! \brief the value of the traits */
static const bool value = false;
};
/*!
* \brief template to select type based on condition
* For example, IfThenElseType<true, int, float>::Type will give int
* \tparam cond the condition
* \tparam Then the typename to be returned if cond is true
* \tparam Else the typename to be returned if cond is false
*/
template<bool cond, typename Then, typename Else>
struct IfThenElseType;
/*! \brief macro to quickly declare traits information */
#define DMLC_DECLARE_TRAITS(Trait, Type, Value) \
template<> \
struct Trait<Type> { \
static const bool value = Value; \
}
/*! \brief macro to quickly declare traits information */
#define DMLC_DECLARE_TYPE_NAME(Type, Name) \
template<> \
inline const char* type_name<Type>() { \
return Name; \
}
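/*!
 * Sketch of how a user-defined type could register these traits
 * (MyIndex is a hypothetical POD struct, not part of dmlc-core):
 * \code
 *
 * struct MyIndex { uint32_t row; uint32_t col; };
 * namespace dmlc {
 * DMLC_DECLARE_TRAITS(is_pod, MyIndex, true);
 * DMLC_DECLARE_TYPE_NAME(MyIndex, "my index type");
 * }  // namespace dmlc
 * \endcode
 */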
//! \cond Doxygen_Suppress
// declare special traits when C++11 is not available
#if DMLC_USE_CXX11 == 0
DMLC_DECLARE_TRAITS(is_pod, char, true);
DMLC_DECLARE_TRAITS(is_pod, int8_t, true);
DMLC_DECLARE_TRAITS(is_pod, int16_t, true);
DMLC_DECLARE_TRAITS(is_pod, int32_t, true);
DMLC_DECLARE_TRAITS(is_pod, int64_t, true);
DMLC_DECLARE_TRAITS(is_pod, uint8_t, true);
DMLC_DECLARE_TRAITS(is_pod, uint16_t, true);
DMLC_DECLARE_TRAITS(is_pod, uint32_t, true);
DMLC_DECLARE_TRAITS(is_pod, uint64_t, true);
DMLC_DECLARE_TRAITS(is_pod, float, true);
DMLC_DECLARE_TRAITS(is_pod, double, true);
DMLC_DECLARE_TRAITS(is_integral, char, true);
DMLC_DECLARE_TRAITS(is_integral, int8_t, true);
DMLC_DECLARE_TRAITS(is_integral, int16_t, true);
DMLC_DECLARE_TRAITS(is_integral, int32_t, true);
DMLC_DECLARE_TRAITS(is_integral, int64_t, true);
DMLC_DECLARE_TRAITS(is_integral, uint8_t, true);
DMLC_DECLARE_TRAITS(is_integral, uint16_t, true);
DMLC_DECLARE_TRAITS(is_integral, uint32_t, true);
DMLC_DECLARE_TRAITS(is_integral, uint64_t, true);
DMLC_DECLARE_TRAITS(is_floating_point, float, true);
DMLC_DECLARE_TRAITS(is_floating_point, double, true);
#endif
DMLC_DECLARE_TYPE_NAME(float, "float");
DMLC_DECLARE_TYPE_NAME(double, "double");
DMLC_DECLARE_TYPE_NAME(int, "int");
DMLC_DECLARE_TYPE_NAME(uint32_t, "int (non-negative)");
DMLC_DECLARE_TYPE_NAME(uint64_t, "long (non-negative)");
DMLC_DECLARE_TYPE_NAME(std::string, "string");
DMLC_DECLARE_TYPE_NAME(bool, "boolean");
template<typename Then, typename Else>
struct IfThenElseType<true, Then, Else> {
typedef Then Type;
};
template<typename Then, typename Else>
struct IfThenElseType<false, Then, Else> {
typedef Else Type;
};
//! \endcond
} // namespace dmlc
#endif // DMLC_TYPE_TRAITS_H_
//===== EXPANDED: ../dmlc-core/include/dmlc/type_traits.h =====
#if DMLC_USE_CXX11
#endif
namespace dmlc {
/*! \brief internal namespace for serializers */
namespace serializer {
/*!
* \brief generic serialization handler
* \tparam T the type to be serialized
*/
template<typename T>
struct Handler;
//! \cond Doxygen_Suppress
/*!
* \brief Serializer that redirects calls by condition
* \tparam cond the condition
* \tparam Then the serializer used for then condition
* \tparam Else the serializer used for else condition
* \tparam Return the type of data the serializer handles
*/
template<bool cond, typename Then, typename Else, typename Return>
struct IfThenElse;
template<typename Then, typename Else, typename T>
struct IfThenElse<true, Then, Else, T> {
inline static void Write(Stream *strm, const T &data) {
Then::Write(strm, data);
}
inline static bool Read(Stream *strm, T *data) {
return Then::Read(strm, data);
}
};
template<typename Then, typename Else, typename T>
struct IfThenElse<false, Then, Else, T> {
inline static void Write(Stream *strm, const T &data) {
Else::Write(strm, data);
}
inline static bool Read(Stream *strm, T *data) {
return Else::Read(strm, data);
}
};
/*! \brief Serializer for POD (plain-old-data) data */
template<typename T>
struct PODHandler {
inline static void Write(Stream *strm, const T &data) {
strm->Write(&data, sizeof(T));
}
inline static bool Read(Stream *strm, T *dptr) {
return strm->Read((void*)dptr, sizeof(T)) == sizeof(T); // NOLINT(*)
}
};
// serializer for classes that have save/load functions
template<typename T>
struct SaveLoadClassHandler {
inline static void Write(Stream *strm, const T &data) {
data.Save(strm);
}
inline static bool Read(Stream *strm, T *data) {
return data->Load(strm);
}
};
/*!
* \brief dummy class for undefined serialization.
* This is used to generate an error message when a user tries to
* serialize something that is not supported.
* \tparam T the type to be serialized
*/
template<typename T>
struct UndefinedSerializerFor {
};
/*!
* \brief Serializer handler for std::vector<T> where T is a POD type.
* \tparam T element type
*/
template<typename T>
struct PODVectorHandler {
inline static void Write(Stream *strm, const std::vector<T> &vec) {
uint64_t sz = static_cast<uint64_t>(vec.size());
strm->Write(&sz, sizeof(sz));
if (sz != 0) {
strm->Write(&vec[0], sizeof(T) * vec.size());
}
}
inline static bool Read(Stream *strm, std::vector<T> *out_vec) {
uint64_t sz;
if (strm->Read(&sz, sizeof(sz)) != sizeof(sz)) return false;
size_t size = static_cast<size_t>(sz);
out_vec->resize(size);
if (sz != 0) {
size_t nbytes = sizeof(T) * size;
return strm->Read(&(*out_vec)[0], nbytes) == nbytes;
}
return true;
}
};
/*!
* \brief Serializer handler for std::vector<T> where T can be a composite type
* \tparam T element type
*/
template<typename T>
struct ComposeVectorHandler {
inline static void Write(Stream *strm, const std::vector<T> &vec) {
uint64_t sz = static_cast<uint64_t>(vec.size());
strm->Write(&sz, sizeof(sz));
for (size_t i = 0; i < vec.size(); ++i) {
Handler<T>::Write(strm, vec[i]);
}
}
inline static bool Read(Stream *strm, std::vector<T> *out_vec) {
uint64_t sz;
if (strm->Read(&sz, sizeof(sz)) != sizeof(sz)) return false;
size_t size = static_cast<size_t>(sz);
out_vec->resize(size);
for (size_t i = 0; i < size; ++i) {
if (!Handler<T>::Read(strm, &(*out_vec)[i])) return false;
}
return true;
}
};
/*!
* \brief Serializer handler for std::basic_string<T> where T is a POD type.
* \tparam T element type
*/
template<typename T>
struct PODStringHandler {
inline static void Write(Stream *strm, const std::basic_string<T> &vec) {
uint64_t sz = static_cast<uint64_t>(vec.length());
strm->Write(&sz, sizeof(sz));
if (sz != 0) {
strm->Write(&vec[0], sizeof(T) * vec.length());
}
}
inline static bool Read(Stream *strm, std::basic_string<T> *out_vec) {
uint64_t sz;
if (strm->Read(&sz, sizeof(sz)) != sizeof(sz)) return false;
size_t size = static_cast<size_t>(sz);
out_vec->resize(size);
if (sz != 0) {
size_t nbytes = sizeof(T) * size;
return strm->Read(&(*out_vec)[0], nbytes) == nbytes;
}
return true;
}
};
/*! \brief Serializer for std::pair */
template<typename TA, typename TB>
struct PairHandler {
inline static void Write(Stream *strm, const std::pair<TA, TB> &data) {
Handler<TA>::Write(strm, data.first);
Handler<TB>::Write(strm, data.second);
}
inline static bool Read(Stream *strm, std::pair<TA, TB> *data) {
return Handler<TA>::Read(strm, &(data->first)) &&
Handler<TB>::Read(strm, &(data->second));
}
};
// handler that can handle most collection type cases (set, map, etc.)
template<typename ContainerType>
struct CollectionHandler {
inline static void Write(Stream *strm, const ContainerType &data) {
typedef typename ContainerType::value_type ElemType;
// dump data to vector
std::vector<ElemType> vdata(data.begin(), data.end());
// serialize the vector
Handler<std::vector<ElemType> >::Write(strm, vdata);
}
inline static bool Read(Stream *strm, ContainerType *data) {
typedef typename ContainerType::value_type ElemType;
std::vector<ElemType> vdata;
if (!Handler<std::vector<ElemType> >::Read(strm, &vdata)) return false;
data->clear();
data->insert(vdata.begin(), vdata.end());
return true;
}
};
// handler that can handle most list type cases;
// this type's insert function takes an additional position iterator
template<typename ListType>
struct ListHandler {
inline static void Write(Stream *strm, const ListType &data) {
typedef typename ListType::value_type ElemType;
// dump data to vector
std::vector<ElemType> vdata(data.begin(), data.end());
// serialize the vector
Handler<std::vector<ElemType> >::Write(strm, vdata);
}
inline static bool Read(Stream *strm, ListType *data) {
typedef typename ListType::value_type ElemType;
std::vector<ElemType> vdata;
if (!Handler<std::vector<ElemType> >::Read(strm, &vdata)) return false;
data->clear();
data->insert(data->begin(), vdata.begin(), vdata.end());
return true;
}
};
//! \endcond
/*!
* \brief generic serialization handler for type T
*
* Users can define a specialization of this class to support
* composite serialization of their own classes.
*
* \tparam T the type to be serialized
*/
template<typename T>
struct Handler {
/*!
* \brief write data to the stream
* \param strm the stream to write the data to.
* \param data the data object to be serialized
*/
inline static void Write(Stream *strm, const T &data) {
IfThenElse<dmlc::is_pod<T>::value,
PODHandler<T>,
IfThenElse<dmlc::has_saveload<T>::value,
SaveLoadClassHandler<T>,
UndefinedSerializerFor<T>, T>,
T>
::Write(strm, data);
}
/*!
* \brief read data from the stream
* \param strm the stream to read the data from.
* \param data pointer to the data object to read into
* \return whether the read is successful
*/
inline static bool Read(Stream *strm, T *data) {
return IfThenElse<dmlc::is_pod<T>::value,
PODHandler<T>,
IfThenElse<dmlc::has_saveload<T>::value,
SaveLoadClassHandler<T>,
UndefinedSerializerFor<T>, T>,
T>
::Read(strm, data);
}
};
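/*!
 * Sketch of how a user class could hook into this dispatch via the
 * has_saveload trait (MyModel below is hypothetical, for illustration):
 * \code
 *
 * class MyModel {
 *  public:
 *   std::vector<float> weights;
 *   inline void Save(dmlc::Stream *fo) const { fo->Write(weights); }
 *   inline bool Load(dmlc::Stream *fi) { return fi->Read(&weights); }
 * };
 * namespace dmlc {
 * DMLC_DECLARE_TRAITS(has_saveload, MyModel, true);
 * }  // namespace dmlc
 * // after this, strm->Write(model) / strm->Read(&model) route through
 * // SaveLoadClassHandler<MyModel>
 * \endcode
 */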
//! \cond Doxygen_Suppress
template<typename T>
struct Handler<std::vector<T> > {
inline static void Write(Stream *strm, const std::vector<T> &data) {
IfThenElse<dmlc::is_pod<T>::value,
PODVectorHandler<T>,
ComposeVectorHandler<T>, std::vector<T> >
::Write(strm, data);
}
inline static bool Read(Stream *strm, std::vector<T> *data) {
return IfThenElse<dmlc::is_pod<T>::value,
PODVectorHandler<T>,
ComposeVectorHandler<T>,
std::vector<T> >
::Read(strm, data);
}
};
template<typename T>
struct Handler<std::basic_string<T> > {
inline static void Write(Stream *strm, const std::basic_string<T> &data) {
IfThenElse<dmlc::is_pod<T>::value,
PODStringHandler<T>,
UndefinedSerializerFor<T>,
std::basic_string<T> >
::Write(strm, data);
}
inline static bool Read(Stream *strm, std::basic_string<T> *data) {
return IfThenElse<dmlc::is_pod<T>::value,
PODStringHandler<T>,
UndefinedSerializerFor<T>,
std::basic_string<T> >
::Read(strm, data);
}
};
template<typename TA, typename TB>
struct Handler<std::pair<TA, TB> > {
inline static void Write(Stream *strm, const std::pair<TA, TB> &data) {
IfThenElse<dmlc::is_pod<TA>::value && dmlc::is_pod<TB>::value,
PODHandler<std::pair<TA, TB> >,
PairHandler<TA, TB>,
std::pair<TA, TB> >
::Write(strm, data);
}
inline static bool Read(Stream *strm, std::pair<TA, TB> *data) {
return IfThenElse<dmlc::is_pod<TA>::value && dmlc::is_pod<TB>::value,
PODHandler<std::pair<TA, TB> >,
PairHandler<TA, TB>,
std::pair<TA, TB> >
::Read(strm, data);
}
};
template<typename K, typename V>
struct Handler<std::map<K, V> >
: public CollectionHandler<std::map<K, V> > {
};
template<typename K, typename V>
struct Handler<std::multimap<K, V> >
: public CollectionHandler<std::multimap<K, V> > {
};
template<typename T>
struct Handler<std::set<T> >
: public CollectionHandler<std::set<T> > {
};
template<typename T>
struct Handler<std::multiset<T> >
: public CollectionHandler<std::multiset<T> > {
};
template<typename T>
struct Handler<std::list<T> >
: public ListHandler<std::list<T> > {
};
template<typename T>
struct Handler<std::deque<T> >
: public ListHandler<std::deque<T> > {
};
#if DMLC_USE_CXX11
template<typename K, typename V>
struct Handler<std::unordered_map<K, V> >
: public CollectionHandler<std::unordered_map<K, V> > {
};
template<typename K, typename V>
struct Handler<std::unordered_multimap<K, V> >
: public CollectionHandler<std::unordered_multimap<K, V> > {
};
template<typename T>
struct Handler<std::unordered_set<T> >
: public CollectionHandler<std::unordered_set<T> > {
};
template<typename T>
struct Handler<std::unordered_multiset<T> >
: public CollectionHandler<std::unordered_multiset<T> > {
};
#endif
//! \endcond
} // namespace serializer
} // namespace dmlc
#endif // DMLC_SERIALIZER_H_
//===== EXPANDED: ../dmlc-core/include/dmlc/serializer.h =====
namespace dmlc {
// implementations of inline functions
template<typename T>
inline void Stream::Write(const T &data) {
serializer::Handler<T>::Write(this, data);
}
template<typename T>
inline bool Stream::Read(T *out_data) {
return serializer::Handler<T>::Read(this, out_data);
}
// implementations for ostream
inline void ostream::OutBuf::set_stream(Stream *stream) {
if (stream_ != NULL) this->pubsync();
this->stream_ = stream;
this->setp(&buffer_[0], &buffer_[0] + buffer_.size() - 1);
}
inline int ostream::OutBuf::sync(void) {
if (stream_ == NULL) return -1;
std::ptrdiff_t n = pptr() - pbase();
stream_->Write(pbase(), n);
this->pbump(-static_cast<int>(n));
bytes_out_ += n;
return 0;
}
inline int ostream::OutBuf::overflow(int c) {
*(this->pptr()) = c;
std::ptrdiff_t n = pptr() - pbase();
this->pbump(-static_cast<int>(n));
if (c == EOF) {
stream_->Write(pbase(), n);
bytes_out_ += n;
} else {
stream_->Write(pbase(), n + 1);
bytes_out_ += n + 1;
}
return c;
}
// implementations for istream
inline void istream::InBuf::set_stream(Stream *stream) {
stream_ = stream;
this->setg(&buffer_[0], &buffer_[0], &buffer_[0]);
}
inline int istream::InBuf::underflow() {
char *bhead = &buffer_[0];
if (this->gptr() == this->egptr()) {
size_t sz = stream_->Read(bhead, buffer_.size());
this->setg(bhead, bhead, bhead + sz);
bytes_read_ += sz;
}
if (this->gptr() == this->egptr()) {
return traits_type::eof();
} else {
return traits_type::to_int_type(*gptr());
}
}
} // namespace dmlc
#endif // DMLC_IO_H_
//===== EXPANDED: ../dmlc-core/include/dmlc/io.h =====
//===== EXPANDING: ../dmlc-core/src/io/line_split.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file line_split.h
* \brief base class implementation of input splitter
* \author Tianqi Chen
*/
#ifndef DMLC_IO_LINE_SPLIT_H_
#define DMLC_IO_LINE_SPLIT_H_
//===== EXPANDING: ../dmlc-core/src/io/input_split_base.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file input_split_base.h
* \brief base class to construct input split from multiple files
* \author Tianqi Chen
*/
#ifndef DMLC_IO_INPUT_SPLIT_BASE_H_
#define DMLC_IO_INPUT_SPLIT_BASE_H_
//===== EXPANDING: ../dmlc-core/src/io/filesys.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file filesys.h
* \brief general file system io interface
* \author Tianqi Chen
*/
#ifndef DMLC_IO_FILESYS_H_
#define DMLC_IO_FILESYS_H_
namespace dmlc {
namespace io {
/*! \brief common data structure for URI */
struct URI {
/*! \brief protocol */
std::string protocol;
/*!
* \brief host name, namenode for HDFS, bucket name for s3
*/
std::string host;
/*! \brief name of the path */
std::string name;
/*! \brief enable default constructor */
URI(void) {}
/*!
* \brief construct from URI string
*/
explicit URI(const char *uri) {
const char *p = std::strstr(uri, "://");
if (p == NULL) {
name = uri;
} else {
protocol = std::string(uri, p - uri + 3);
uri = p + 3;
p = std::strchr(uri, '/');
if (p == NULL) {
host = uri; name = '/';
} else {
host = std::string(uri, p - uri);
name = p;
}
}
}
/*! \brief string representation */
inline std::string str(void) const {
return protocol + host + name;
}
};
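/*!
 * Parsing sketch for the URI constructor above, e.g. for
 * "hdfs://nn1/user/data.txt" (a hypothetical path):
 * \code
 *
 * dmlc::io::URI uri("hdfs://nn1/user/data.txt");
 * // uri.protocol == "hdfs://"
 * // uri.host == "nn1"
 * // uri.name == "/user/data.txt"
 * // a plain path such as "data.txt" leaves protocol and host empty
 * \endcode
 */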
/*! \brief type of file */
enum FileType {
/*! \brief the file is a regular file */
kFile,
/*! \brief the file is a directory */
kDirectory
};
/*! \brief used to store file information */
struct FileInfo {
/*! \brief full path to the file */
URI path;
/*! \brief the size of the file */
size_t size;
/*! \brief the type of the file */
FileType type;
/*! \brief default constructor */
FileInfo() : size(0), type(kFile) {}
};
/*! \brief file system interface */
class FileSystem {
public:
/*!
* \brief get singleton of filesystem instance according to URI
* \param path can be s3://..., hdfs://..., file://...,
* or an empty string (will return the local filesystem)
* \return a corresponding filesystem; reports an error if
* we cannot find a matching system
*/
static FileSystem *GetInstance(const URI &path);
/*! \brief virtual destructor */
virtual ~FileSystem() {}
/*!
* \brief get information about a path
* \param path the path to the file
* \return the information about the file
*/
virtual FileInfo GetPathInfo(const URI &path) = 0;
/*!
* \brief list files in a directory
* \param path the path to the directory
* \param out_list the output information about the files
*/
virtual void ListDirectory(const URI &path, std::vector<FileInfo> *out_list) = 0;
/*!
* \brief open a stream
* \param path the path to the file
* \param flag can be "w", "r", "a"
* \param allow_null whether NULL can be returned, or directly report an error
* \return the created stream, can be NULL when allow_null == true and the file does not exist
*/
virtual Stream *Open(const URI &path,
const char* const flag,
bool allow_null = false) = 0;
/*!
* \brief open a seekable stream for read
* \param path the path to the file
* \param allow_null whether NULL can be returned, or directly report an error
* \return the created stream, can be NULL when allow_null == true and the file does not exist
*/
virtual SeekStream *OpenForRead(const URI &path,
bool allow_null = false) = 0;
};
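/*!
 * Usage sketch for FileSystem (the path ./data.txt is an assumption for
 * illustration; GetInstance picks the backend from the URI protocol):
 * \code
 *
 * dmlc::io::URI path("./data.txt");
 * dmlc::io::FileSystem *fs = dmlc::io::FileSystem::GetInstance(path);
 * dmlc::io::FileInfo info = fs->GetPathInfo(path);
 * dmlc::Stream *fp = fs->Open(path, "r");
 * // ... read from fp ...
 * delete fp;
 * // fs is a singleton (see GetInstance above); do not delete it
 * \endcode
 */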
} // namespace io
} // namespace dmlc
#endif // DMLC_IO_FILESYS_H_
//===== EXPANDED: ../dmlc-core/src/io/filesys.h =====
namespace dmlc {
namespace io {
/*! \brief class to construct input split from multiple files */
class InputSplitBase : public InputSplit {
public:
/*!
* \brief helper struct to hold chunk data
* with internal pointer to move along the record
*/
struct Chunk {
char *begin;
char *end;
std::vector<size_t> data;
explicit Chunk(size_t buffer_size)
: begin(NULL), end(NULL),
data(buffer_size + 1) {}
// load chunk from split
bool Load(InputSplitBase *split, size_t buffer_size);
};
// 16 MB
static const size_t kBufferSize = 2UL << 20UL;
// destructor
virtual ~InputSplitBase(void);
// implement BeforeFirst
virtual void BeforeFirst(void);
virtual void HintChunkSize(size_t chunk_size) {
buffer_size_ = std::max(chunk_size / sizeof(size_t), buffer_size_);
}
virtual size_t GetTotalSize(void) {
return file_offset_.back();
}
// implement next record
virtual bool NextRecord(Blob *out_rec) {
while (!ExtractNextRecord(out_rec, &tmp_chunk_)) {
if (!tmp_chunk_.Load(this, buffer_size_)) return false;
}
return true;
}
// implement next chunk
virtual bool NextChunk(Blob *out_chunk) {
while (!ExtractNextChunk(out_chunk, &tmp_chunk_)) {
if (!tmp_chunk_.Load(this, buffer_size_)) return false;
}
return true;
}
// implement ResetPartition.
virtual void ResetPartition(unsigned rank, unsigned nsplit);
/*!
* \brief read a chunk of data into buf
* the data can span multiple records,
* but cannot contain partial records
*
* \param buf the memory region of the buffer,
* should be properly aligned to 64 bits
* \param size the maximum size of memory,
* after the function returns, it stores the size of the chunk
* \return whether end of file was reached
*/
bool ReadChunk(void *buf, size_t *size);
/*!
* \brief extract the next chunk from the chunk
* \param out_rchunk the output chunk
* \param chunk the chunk information
* \return true if a non-empty chunk is extracted
* false if the chunk has already finished its life
*/
bool ExtractNextChunk(Blob *out_rchunk, Chunk *chunk);
/*!
* \brief extract next record from the chunk
* \param out_rec the output record
* \param chunk the chunk information
* \return true if a non-empty record is extracted
* false if the chunk has already finished its life
*/
virtual bool ExtractNextRecord(Blob *out_rec, Chunk *chunk) = 0;
protected:
// constructor
InputSplitBase()
: fs_(NULL),
align_bytes_(8),
tmp_chunk_(kBufferSize),
buffer_size_(kBufferSize) {}
/*!
* \brief initialize the base before doing anything
* \param fs the filesystem ptr
* \param uri the uri of the files
* \param align_bytes the head split must be a multiple of align_bytes;
* this also checks that file sizes are multiples of align_bytes
*/
void Init(FileSystem *fs,
const char *uri,
size_t align_bytes);
// to be implemented by child class
/*!
* \brief seek to the beginning of the first record
* starting from the current file pointer
* \return how many bytes we read past
*/
virtual size_t SeekRecordBegin(Stream *fi) = 0;
/*!
* \brief find the last occurrence of a record header
* \param begin beginning of the buffer
* \param end end of the buffer
* \return the pointer between [begin, end] indicating the
* last record head
*/
virtual const char*
FindLastRecordBegin(const char *begin, const char *end) = 0;
private:
/*! \brief FileSystem */
FileSystem *filesys_;
/*! \brief information about files */
std::vector<FileInfo> files_;
/*! \brief current input stream */
SeekStream *fs_;
/*! \brief bytes to be aligned */
size_t align_bytes_;
/*! \brief file pointer of which file to read on */
size_t file_ptr_;
/*! \brief file pointer where the end of file lies */
size_t file_ptr_end_;
/*! \brief get the current offset */
size_t offset_curr_;
/*! \brief beginning of offset */
size_t offset_begin_;
/*! \brief end of the offset */
size_t offset_end_;
/*! \brief temporary chunk */
Chunk tmp_chunk_;
/*! \brief buffer size */
size_t buffer_size_;
/*! \brief byte-offset of each file */
std::vector<size_t> file_offset_;
/*! \brief internal overflow buffer */
std::string overflow_;
/*! \brief initialize information in files */
void InitInputFileInfo(const std::string& uri);
/*! \brief strip continuous chars at the end of str */
std::string StripEnd(std::string str, char ch);
/*! \brief same as stream.Read */
size_t Read(void *ptr, size_t size);
};
} // namespace io
} // namespace dmlc
#endif // DMLC_IO_INPUT_SPLIT_BASE_H_
//===== EXPANDED: ../dmlc-core/src/io/input_split_base.h =====
namespace dmlc {
namespace io {
/*! \brief class that splits the files by line */
class LineSplitter : public InputSplitBase {
public:
LineSplitter(FileSystem *fs,
const char *uri,
unsigned rank,
unsigned nsplit) {
this->Init(fs, uri, 1);
this->ResetPartition(rank, nsplit);
}
virtual bool ExtractNextRecord(Blob *out_rec, Chunk *chunk);
protected:
virtual size_t SeekRecordBegin(Stream *fi);
virtual const char*
FindLastRecordBegin(const char *begin, const char *end);
};
} // namespace io
} // namespace dmlc
#endif // DMLC_IO_LINE_SPLIT_H_
//===== EXPANDED: ../dmlc-core/src/io/line_split.h =====
namespace dmlc {
namespace io {
size_t LineSplitter::SeekRecordBegin(Stream *fi) {
char c = '\0';
size_t nstep = 0;
// search till the first end-of-line
while (true) {
if (fi->Read(&c, sizeof(c)) == 0) return nstep;
nstep += 1;
if (c == '\n' || c == '\r') break;
}
// search until the first non-end-of-line
while (true) {
if (fi->Read(&c, sizeof(c)) == 0) return nstep;
if (c != '\n' && c != '\r') break;
// non-end-of-line should not count
nstep += 1;
}
return nstep;
}
const char* LineSplitter::FindLastRecordBegin(const char *begin,
const char *end) {
CHECK(begin != end);
for (const char *p = end - 1; p != begin; --p) {
if (*p == '\n' || *p == '\r') return p + 1;
}
return begin;
}
bool LineSplitter::ExtractNextRecord(Blob *out_rec, Chunk *chunk) {
if (chunk->begin == chunk->end) return false;
char *p;
for (p = chunk->begin; p != chunk->end; ++p) {
if (*p == '\n' || *p == '\r') break;
}
for (; p != chunk->end; ++p) {
if (*p != '\n' && *p != '\r') break;
}
// set the string end sign for safety
if (p == chunk->end) {
*p = '\0';
} else {
*(p - 1) = '\0';
}
out_rec->dptr = chunk->begin;
out_rec->size = p - chunk->begin;
chunk->begin = p;
return true;
}
} // namespace io
} // namespace dmlc
//===== EXPANDED: ../dmlc-core/src/io/line_split.cc =====
//===== EXPANDING: ../dmlc-core/src/io/recordio_split.cc =====
// Copyright by Contributors
//===== EXPANDING: ../dmlc-core/include/dmlc/recordio.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file recordio.h
* \brief recordio that is able to pack binary data into a splittable
* format, useful to exchange data in binary serialization,
* such as binary raw data or protobuf
*/
#ifndef DMLC_RECORDIO_H_
#define DMLC_RECORDIO_H_
namespace dmlc {
/*!
* \brief writer of binary recordio
* binary format for recordio
* recordio format: magic lrecord data pad
*
* - magic is magic number
* - pad is simply a padding space to make record align to 4 bytes
* - lrecord encodes length and continue bit
* - data.length() = (lrecord & (1U<<29U - 1));
* - cflag == (lrecord >> 29U) & 7;
*
* cflag is used to handle the (rare) special case when the magic number
* occurs in the data sequence.
*
* In such a case, the data is split into multiple records at
* the cells containing the magic number
*
* (1) cflag == 0: this is a complete record;
* (2) cflag == 1: start of a multiple-rec;
* cflag == 2: middle of multiple-rec;
* cflag == 3: end of multiple-rec
*/
class RecordIOWriter {
public:
/*!
* \brief magic number of recordio
* note: (kMagic >> 29U) & 7 > 3
* this ensures lrec will not be kMagic
*/
static const uint32_t kMagic = 0xced7230a;
/*!
* \brief encode the lrecord
* \param cflag cflag part of the lrecord
* \param length length part of lrecord
* \return the encoded data
*/
inline static uint32_t EncodeLRec(uint32_t cflag, uint32_t length) {
return (cflag << 29U) | length;
}
/*!
* \brief decode the flag part of lrecord
* \param rec the lrecord
* \return the flag
*/
inline static uint32_t DecodeFlag(uint32_t rec) {
return (rec >> 29U) & 7U;
}
/*!
* \brief decode the length part of lrecord
* \param rec the lrecord
* \return the length
*/
inline static uint32_t DecodeLength(uint32_t rec) {
return rec & ((1U << 29U) - 1U);
}
/*!
* \brief constructor
* \param stream the stream to be constructed
*/
explicit RecordIOWriter(Stream *stream)
: stream_(stream), seek_stream_(dynamic_cast<SeekStream*>(stream)),
except_counter_(0) {
CHECK(sizeof(uint32_t) == 4) << "uint32_t needs to be 4 bytes";
}
/*!
* \brief write record to the stream
* \param buf the buffer of memory region
* \param size the size of record to write out
*/
void WriteRecord(const void *buf, size_t size);
/*!
* \brief write record to the stream
* \param data the data to write out
*/
inline void WriteRecord(const std::string &data) {
this->WriteRecord(data.c_str(), data.length());
}
/*!
* \return number of exceptions (occurrences of the magic number)
* during the writing process
*/
inline size_t except_counter(void) const {
return except_counter_;
}
/*! \brief tell the current position of the input stream */
inline size_t Tell(void) {
CHECK(seek_stream_ != NULL) << "The input stream is not seekable";
return seek_stream_->Tell();
}
private:
/*! \brief output stream */
Stream *stream_;
/*! \brief seekable stream */
SeekStream *seek_stream_;
/*! \brief counts the number of exceptions */
size_t except_counter_;
};
/*!
* \brief reader of binary recordio that reads records from a stream
* \sa RecordIOWriter
*/
class RecordIOReader {
public:
/*!
* \brief constructor
* \param stream the stream to be constructed
*/
explicit RecordIOReader(Stream *stream)
: stream_(stream), seek_stream_(dynamic_cast<SeekStream*>(stream)),
end_of_stream_(false) {
CHECK(sizeof(uint32_t) == 4) << "uint32_t needs to be 4 bytes";
}
/*!
* \brief read next complete record from stream
* \param out_rec used to store the output record as a string
* \return true if the read was successful, false if the end of the stream was reached
*/
bool NextRecord(std::string *out_rec);
/*! \brief seek to certain position of the input stream */
inline void Seek(size_t pos) {
CHECK(seek_stream_ != NULL) << "The input stream is not seekable";
seek_stream_->Seek(pos);
}
private:
/*! \brief input stream */
Stream *stream_;
SeekStream *seek_stream_;
/*! \brief whether we are at end of stream */
bool end_of_stream_;
};
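/*!
 * Usage sketch combining RecordIOWriter and RecordIOReader (the path
 * ./data.rec is an assumption for illustration):
 * \code
 *
 * dmlc::Stream *fo = dmlc::Stream::Create("./data.rec", "w");
 * dmlc::RecordIOWriter writer(fo);
 * writer.WriteRecord(std::string("hello"));
 * writer.WriteRecord(std::string("world"));
 * delete fo;
 *
 * dmlc::Stream *fi = dmlc::Stream::Create("./data.rec", "r");
 * dmlc::RecordIOReader reader(fi);
 * std::string rec;
 * while (reader.NextRecord(&rec)) {
 *   // rec holds one record payload, header already removed
 * }
 * delete fi;
 * \endcode
 */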
/*!
* \brief reader of binary recordio from a Blob returned by InputSplit
* This class divides the blob into several independent parts specified by the caller,
* and reads from one segment.
* The part reading can be used together with InputSplit::NextChunk for
* multi-threaded parsing (each thread takes a RecordIOChunkReader)
*
* \sa RecordIOWriter, InputSplit
*/
class RecordIOChunkReader {
public:
/*!
* \brief constructor
* \param chunk source data returned by InputSplit
* \param part_index which part we want to read
* \param num_parts number of total segments
*/
explicit RecordIOChunkReader(InputSplit::Blob chunk,
unsigned part_index = 0,
unsigned num_parts = 1);
/*!
* \brief read next complete record from stream
* the blob contains the memory content
* NOTE: this function is not threadsafe, use one
* RecordIOChunkReader per thread
* \param out_rec used to store output blob, the header is already
* removed and out_rec only contains the memory content
   * \return true if the read was successful, false if the end was reached
*/
bool NextRecord(InputSplit::Blob *out_rec);
private:
  /*! \brief internal temporary data */
std::string temp_;
/*! \brief internal data pointer */
char *pbegin_, *pend_;
};
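/*
 * A usage sketch combining InputSplit::NextChunk with RecordIOChunkReader,
 * assuming `split` is an InputSplit created for recordio data (for example via
 * the InputSplit::Create factory declared in io.h):
 *
 *   dmlc::InputSplit::Blob chunk;
 *   while (split->NextChunk(&chunk)) {
 *     // each chunk can be handed to its own thread
 *     dmlc::RecordIOChunkReader reader(chunk, 0, 1);
 *     dmlc::InputSplit::Blob rec;
 *     while (reader.NextRecord(&rec)) {
 *       // rec.dptr / rec.size cover one record payload, header already removed
 *     }
 *   }
 */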
} // namespace dmlc
#endif // DMLC_RECORDIO_H_
//===== EXPANDED: ../dmlc-core/include/dmlc/recordio.h =====
//===== EXPANDING: ../dmlc-core/src/io/recordio_split.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file recordio_split.h
* \brief input split that splits recordio files
* \author Tianqi Chen
*/
#ifndef DMLC_IO_RECORDIO_SPLIT_H_
#define DMLC_IO_RECORDIO_SPLIT_H_
namespace dmlc {
namespace io {
/*! \brief class that splits the files by recordio records */
class RecordIOSplitter : public InputSplitBase {
public:
RecordIOSplitter(FileSystem *fs,
const char *uri,
unsigned rank,
unsigned nsplit) {
this->Init(fs, uri, 4);
this->ResetPartition(rank, nsplit);
}
virtual bool ExtractNextRecord(Blob *out_rec, Chunk *chunk);
protected:
virtual size_t SeekRecordBegin(Stream *fi);
virtual const char*
FindLastRecordBegin(const char *begin, const char *end);
};
} // namespace io
} // namespace dmlc
#endif // DMLC_IO_RECORDIO_SPLIT_H_
//===== EXPANDED: ../dmlc-core/src/io/recordio_split.h =====
namespace dmlc {
namespace io {
size_t RecordIOSplitter::SeekRecordBegin(Stream *fi) {
size_t nstep = 0;
uint32_t v, lrec;
while (true) {
if (fi->Read(&v, sizeof(v)) == 0) return nstep;
nstep += sizeof(v);
if (v == RecordIOWriter::kMagic) {
CHECK(fi->Read(&lrec, sizeof(lrec)) != 0)
<< "invalid record io format";
nstep += sizeof(lrec);
uint32_t cflag = RecordIOWriter::DecodeFlag(lrec);
if (cflag == 0 || cflag == 1) break;
}
}
// should point at head of record
return nstep - 2 * sizeof(uint32_t);
}
const char* RecordIOSplitter::FindLastRecordBegin(const char *begin,
const char *end) {
CHECK_EQ((reinterpret_cast<size_t>(begin) & 3UL), 0U);
CHECK_EQ((reinterpret_cast<size_t>(end) & 3UL), 0U);
const uint32_t *pbegin = reinterpret_cast<const uint32_t *>(begin);
const uint32_t *p = reinterpret_cast<const uint32_t *>(end);
CHECK(p >= pbegin + 2);
for (p = p - 2; p != pbegin; --p) {
if (p[0] == RecordIOWriter::kMagic) {
uint32_t cflag = RecordIOWriter::DecodeFlag(p[1]);
if (cflag == 0 || cflag == 1) {
return reinterpret_cast<const char*>(p);
}
}
}
return begin;
}
bool RecordIOSplitter::ExtractNextRecord(Blob *out_rec, Chunk *chunk) {
if (chunk->begin == chunk->end) return false;
CHECK(chunk->begin + 2 * sizeof(uint32_t) <= chunk->end)
<< "Invalid RecordIO Format";
CHECK_EQ((reinterpret_cast<size_t>(chunk->begin) & 3UL), 0U);
CHECK_EQ((reinterpret_cast<size_t>(chunk->end) & 3UL), 0U);
uint32_t *p = reinterpret_cast<uint32_t *>(chunk->begin);
uint32_t cflag = RecordIOWriter::DecodeFlag(p[1]);
uint32_t clen = RecordIOWriter::DecodeLength(p[1]);
// skip header
out_rec->dptr = chunk->begin + 2 * sizeof(uint32_t);
// move pbegin
chunk->begin += 2 * sizeof(uint32_t) + (((clen + 3U) >> 2U) << 2U);
CHECK(chunk->begin <= chunk->end) << "Invalid RecordIO Format";
out_rec->size = clen;
if (cflag == 0) return true;
const uint32_t kMagic = RecordIOWriter::kMagic;
// abnormal path, move data around to make a full part
CHECK(cflag == 1U) << "Invalid RecordIO Format";
while (cflag != 3U) {
CHECK(chunk->begin + 2 * sizeof(uint32_t) <= chunk->end);
p = reinterpret_cast<uint32_t *>(chunk->begin);
CHECK(p[0] == RecordIOWriter::kMagic);
cflag = RecordIOWriter::DecodeFlag(p[1]);
clen = RecordIOWriter::DecodeLength(p[1]);
      // re-insert kMagic between the reassembled parts
std::memcpy(reinterpret_cast<char*>(out_rec->dptr) + out_rec->size,
&kMagic, sizeof(kMagic));
out_rec->size += sizeof(kMagic);
// move the rest of the blobs
if (clen != 0) {
std::memmove(reinterpret_cast<char*>(out_rec->dptr) + out_rec->size,
chunk->begin + 2 * sizeof(uint32_t), clen);
out_rec->size += clen;
}
chunk->begin += 2 * sizeof(uint32_t) + (((clen + 3U) >> 2U) << 2U);
}
return true;
}
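// Layout handled above, as implied by the decode logic: every part on disk is
// [kMagic][lrec][payload padded to a multiple of 4 bytes]; cflag 0 marks a
// complete record, 1 the first part of a split record and 3 the last part.
// The reassembly loop re-inserts kMagic between consecutive parts, since the
// writer splits a record exactly at payload positions containing the magic word.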
} // namespace io
} // namespace dmlc
//===== EXPANDED: ../dmlc-core/src/io/recordio_split.cc =====
//===== EXPANDING: ../dmlc-core/src/io/input_split_base.cc =====
// Copyright by Contributors
//===== EXPANDING: ../dmlc-core/include/dmlc/common.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file common.h
* \brief defines some common utility function.
*/
#ifndef DMLC_COMMON_H_
#define DMLC_COMMON_H_
namespace dmlc {
/*!
* \brief Split a string by delimiter
* \param s String to be splitted.
* \param delim The delimiter.
* \return a splitted vector of strings.
*/
inline std::vector<std::string> Split(const std::string& s, char delim) {
std::string item;
std::istringstream is(s);
std::vector<std::string> ret;
while (std::getline(is, item, delim)) {
ret.push_back(item);
}
return ret;
}
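/*
 * Example of the splitting behaviour:
 *
 *   std::vector<std::string> parts = dmlc::Split("cat;dog;bird", ';');
 *   // parts == {"cat", "dog", "bird"}
 */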
} // namespace dmlc
#endif // DMLC_COMMON_H_
//===== EXPANDED: ../dmlc-core/include/dmlc/common.h =====
#if DMLC_USE_REGEX
#endif
namespace dmlc {
namespace io {
void InputSplitBase::Init(FileSystem *filesys,
const char *uri,
size_t align_bytes) {
this->filesys_ = filesys;
// initialize the path
this->InitInputFileInfo(uri);
file_offset_.resize(files_.size() + 1);
file_offset_[0] = 0;
for (size_t i = 0; i < files_.size(); ++i) {
file_offset_[i + 1] = file_offset_[i] + files_[i].size;
CHECK(files_[i].size % align_bytes == 0)
<< "file do not align by " << align_bytes << " bytes";
}
this->align_bytes_ = align_bytes;
}
void InputSplitBase::ResetPartition(unsigned rank,
unsigned nsplit) {
size_t ntotal = file_offset_.back();
size_t nstep = (ntotal + nsplit - 1) / nsplit;
  // align nstep to align_bytes_
nstep = ((nstep + align_bytes_ - 1) / align_bytes_) * align_bytes_;
offset_begin_ = std::min(nstep * rank, ntotal);
offset_end_ = std::min(nstep * (rank + 1), ntotal);
offset_curr_ = offset_begin_;
if (offset_begin_ == offset_end_) return;
file_ptr_ = std::upper_bound(file_offset_.begin(),
file_offset_.end(),
offset_begin_) - file_offset_.begin() - 1;
file_ptr_end_ = std::upper_bound(file_offset_.begin(),
file_offset_.end(),
offset_end_) - file_offset_.begin() - 1;
if (fs_ != NULL) {
delete fs_; fs_ = NULL;
}
// find the exact ending position
if (offset_end_ != file_offset_[file_ptr_end_]) {
    CHECK(offset_end_ > file_offset_[file_ptr_end_]);
CHECK(file_ptr_end_ < files_.size());
fs_ = filesys_->OpenForRead(files_[file_ptr_end_].path);
fs_->Seek(offset_end_ - file_offset_[file_ptr_end_]);
offset_end_ += SeekRecordBegin(fs_);
delete fs_;
}
fs_ = filesys_->OpenForRead(files_[file_ptr_].path);
if (offset_begin_ != file_offset_[file_ptr_]) {
fs_->Seek(offset_begin_ - file_offset_[file_ptr_]);
offset_begin_ += SeekRecordBegin(fs_);
}
this->BeforeFirst();
}
void InputSplitBase::BeforeFirst(void) {
if (offset_begin_ >= offset_end_) return;
size_t fp = std::upper_bound(file_offset_.begin(),
file_offset_.end(),
offset_begin_) - file_offset_.begin() - 1;
if (file_ptr_ != fp) {
delete fs_;
file_ptr_ = fp;
fs_ = filesys_->OpenForRead(files_[file_ptr_].path);
}
// seek to beginning of stream
fs_->Seek(offset_begin_ - file_offset_[file_ptr_]);
offset_curr_ = offset_begin_;
tmp_chunk_.begin = tmp_chunk_.end = NULL;
// clear overflow buffer
overflow_.clear();
}
InputSplitBase::~InputSplitBase(void) {
delete fs_;
  // no need to delete filesystem, it is a singleton
}
std::string InputSplitBase::StripEnd(std::string str, char ch) {
while (str.length() != 0 && str[str.length() - 1] == ch) {
str.resize(str.length() - 1);
}
return str;
}
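// For example, StripEnd("part-0001///", '/') returns "part-0001"; this is how
// listed directory entries are compared against the requested path below.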
void InputSplitBase::InitInputFileInfo(const std::string& uri) {
  // split the uri by ';'
const char dlm = ';';
std::vector<std::string> file_list = Split(uri, dlm);
std::vector<URI> expanded_list;
// expand by match regex pattern.
for (size_t i = 0; i < file_list.size(); ++i) {
URI path(file_list[i].c_str());
size_t pos = path.name.rfind('/');
if (pos == std::string::npos || pos + 1 == path.name.length()) {
expanded_list.push_back(path);
} else {
URI dir = path;
dir.name = path.name.substr(0, pos);
std::vector<FileInfo> dfiles;
filesys_->ListDirectory(dir, &dfiles);
bool exact_match = false;
for (size_t i = 0; i < dfiles.size(); ++i) {
if (StripEnd(dfiles[i].path.name, '/') == StripEnd(path.name, '/')) {
expanded_list.push_back(dfiles[i].path);
exact_match = true;
break;
}
}
#if DMLC_USE_REGEX
if (!exact_match) {
std::string spattern = path.name;
try {
std::regex pattern(spattern);
for (size_t i = 0; i < dfiles.size(); ++i) {
if (dfiles[i].type != kFile || dfiles[i].size == 0) continue;
std::string stripped = StripEnd(dfiles[i].path.name, '/');
std::smatch base_match;
if (std::regex_match(stripped, base_match, pattern)) {
for (size_t j = 0; j < base_match.size(); ++j) {
if (base_match[j].str() == stripped) {
expanded_list.push_back(dfiles[i].path); break;
}
}
}
}
} catch (std::regex_error& e) {
LOG(FATAL) << e.what() << " bad regex " << spattern
<< "This could due to compiler version, g++-4.9 is needed";
}
}
#endif // DMLC_USE_REGEX
}
}
for (size_t i = 0; i < expanded_list.size(); ++i) {
const URI& path = expanded_list[i];
FileInfo info = filesys_->GetPathInfo(path);
if (info.type == kDirectory) {
std::vector<FileInfo> dfiles;
filesys_->ListDirectory(info.path, &dfiles);
for (size_t i = 0; i < dfiles.size(); ++i) {
if (dfiles[i].size != 0 && dfiles[i].type == kFile) {
files_.push_back(dfiles[i]);
}
}
} else {
if (info.size != 0) {
files_.push_back(info);
}
}
}
CHECK_NE(files_.size(), 0U)
<< "Cannot find any files that matches the URI patternz " << uri;
}
size_t InputSplitBase::Read(void *ptr, size_t size) {
if (offset_begin_ >= offset_end_) return 0;
if (offset_curr_ + size > offset_end_) {
size = offset_end_ - offset_curr_;
}
if (size == 0) return 0;
size_t nleft = size;
char *buf = reinterpret_cast<char*>(ptr);
while (true) {
size_t n = fs_->Read(buf, nleft);
nleft -= n; buf += n;
offset_curr_ += n;
if (nleft == 0) break;
if (n == 0) {
if (offset_curr_ != file_offset_[file_ptr_ + 1]) {
LOG(ERROR) << "curr=" << offset_curr_
<< ",begin=" << offset_begin_
<< ",end=" << offset_end_
<< ",fileptr=" << file_ptr_
<< ",fileoffset=" << file_offset_[file_ptr_ + 1];
for (size_t i = 0; i < file_ptr_; ++i) {
LOG(ERROR) << "offset[" << i << "]=" << file_offset_[i];
}
LOG(FATAL) << "file offset not calculated correctly";
}
if (file_ptr_ + 1 >= files_.size()) break;
file_ptr_ += 1;
delete fs_;
fs_ = filesys_->OpenForRead(files_[file_ptr_].path);
}
}
return size - nleft;
}
bool InputSplitBase::ReadChunk(void *buf, size_t *size) {
size_t max_size = *size;
if (max_size <= overflow_.length()) {
*size = 0; return true;
}
if (overflow_.length() != 0) {
std::memcpy(buf, BeginPtr(overflow_), overflow_.length());
}
size_t olen = overflow_.length();
overflow_.resize(0);
size_t nread = this->Read(reinterpret_cast<char*>(buf) + olen,
max_size - olen);
nread += olen;
if (nread == 0) return false;
if (nread != max_size) {
*size = nread;
return true;
} else {
const char *bptr = reinterpret_cast<const char*>(buf);
// return the last position where a record starts
const char *bend = this->FindLastRecordBegin(bptr, bptr + max_size);
*size = bend - bptr;
overflow_.resize(max_size - *size);
if (overflow_.length() != 0) {
std::memcpy(BeginPtr(overflow_), bend, overflow_.length());
}
return true;
}
}
bool InputSplitBase::Chunk::Load(InputSplitBase *split, size_t buffer_size) {
if (buffer_size + 1 > data.size()) {
data.resize(buffer_size + 1);
}
while (true) {
// leave one tail chunk
size_t size = (data.size() - 1) * sizeof(size_t);
// set back to 0 for string safety
data.back() = 0;
if (!split->ReadChunk(BeginPtr(data), &size)) return false;
if (size == 0) {
data.resize(data.size() * 2);
} else {
begin = reinterpret_cast<char *>(BeginPtr(data));
end = begin + size;
break;
}
}
return true;
}
bool InputSplitBase::ExtractNextChunk(Blob *out_chunk, Chunk *chunk) {
if (chunk->begin == chunk->end) return false;
out_chunk->dptr = chunk->begin;
out_chunk->size = chunk->end - chunk->begin;
chunk->begin = chunk->end;
return true;
}
} // namespace io
} // namespace dmlc
//===== EXPANDED: ../dmlc-core/src/io/input_split_base.cc =====
//===== EXPANDING: ../dmlc-core/src/io/local_filesys.cc =====
// Copyright by Contributors
extern "C" {
}
#ifndef _MSC_VER
extern "C" {
}
#else
#define stat _stat64
#endif
//===== EXPANDING: ../dmlc-core/src/io/local_filesys.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file local_filesys.h
* \brief local access module
* \author Tianqi Chen
*/
#ifndef DMLC_IO_LOCAL_FILESYS_H_
#define DMLC_IO_LOCAL_FILESYS_H_
namespace dmlc {
namespace io {
/*! \brief local file system */
class LocalFileSystem : public FileSystem {
public:
/*! \brief destructor */
virtual ~LocalFileSystem() {}
/*!
* \brief get information about a path
* \param path the path to the file
* \return the information about the file
*/
virtual FileInfo GetPathInfo(const URI &path);
/*!
* \brief list files in a directory
* \param path to the file
* \param out_list the output information about the files
*/
virtual void ListDirectory(const URI &path, std::vector<FileInfo> *out_list);
/*!
   * \brief open a stream; reports an error and exits if something bad happens
   * NOTE: the Stream can continue to work even after the filesystem is destructed
   * \param path path to the file
   * \param flag the flag used to open the file ("r", "w", ...)
   * \param allow_null whether NULL can be returned, or whether to directly report an error
   * \return the created stream, can be NULL when allow_null == true and the file does not exist
*/
virtual SeekStream *Open(const URI &path,
const char* const flag,
bool allow_null);
/*!
* \brief open a seekable stream for read
* \param path the path to the file
   * \param allow_null whether NULL can be returned, or whether to directly report an error
   * \return the created stream, can be NULL when allow_null == true and the file does not exist
*/
virtual SeekStream *OpenForRead(const URI &path, bool allow_null);
/*!
* \brief get a singleton of LocalFileSystem when needed
* \return a singleton instance
*/
inline static LocalFileSystem *GetInstance(void) {
static LocalFileSystem instance;
return &instance;
}
private:
LocalFileSystem() {}
};
} // namespace io
} // namespace dmlc
#endif // DMLC_IO_LOCAL_FILESYS_H_
//===== EXPANDED: ../dmlc-core/src/io/local_filesys.h =====
#if defined(__FreeBSD__)
#define fopen64 std::fopen
#endif
namespace dmlc {
namespace io {
/*! \brief implementation of file i/o stream */
class FileStream : public SeekStream {
public:
explicit FileStream(FILE *fp, bool use_stdio)
: fp_(fp), use_stdio_(use_stdio) {}
virtual ~FileStream(void) {
this->Close();
}
virtual size_t Read(void *ptr, size_t size) {
return std::fread(ptr, 1, size, fp_);
}
virtual void Write(const void *ptr, size_t size) {
CHECK(std::fwrite(ptr, 1, size, fp_) == size)
<< "FileStream.Write incomplete";
}
virtual void Seek(size_t pos) {
CHECK(!std::fseek(fp_, static_cast<long>(pos), SEEK_SET)); // NOLINT(*)
}
virtual size_t Tell(void) {
return std::ftell(fp_);
}
virtual bool AtEnd(void) const {
return std::feof(fp_) != 0;
}
inline void Close(void) {
if (fp_ != NULL && !use_stdio_) {
std::fclose(fp_); fp_ = NULL;
}
}
private:
std::FILE *fp_;
bool use_stdio_;
};
FileInfo LocalFileSystem::GetPathInfo(const URI &path) {
struct stat sb;
if (stat(path.name.c_str(), &sb) == -1) {
int errsv = errno;
LOG(FATAL) << "LocalFileSystem.GetPathInfo " << path.name
<< " Error:" << strerror(errsv);
}
FileInfo ret;
ret.path = path;
ret.size = sb.st_size;
if ((sb.st_mode & S_IFMT) == S_IFDIR) {
ret.type = kDirectory;
} else {
ret.type = kFile;
}
return ret;
}
void LocalFileSystem::ListDirectory(const URI &path, std::vector<FileInfo> *out_list) {
#ifndef _MSC_VER
DIR *dir = opendir(path.name.c_str());
if (dir == NULL) {
int errsv = errno;
LOG(FATAL) << "LocalFileSystem.ListDirectory " << path.str()
<<" error: " << strerror(errsv);
}
out_list->clear();
struct dirent *ent;
  /* list all the files and directories within the directory */
while ((ent = readdir(dir)) != NULL) {
if (!strcmp(ent->d_name, ".")) continue;
if (!strcmp(ent->d_name, "..")) continue;
URI pp = path;
if (pp.name[pp.name.length() - 1] != '/') {
pp.name += '/';
}
pp.name += ent->d_name;
out_list->push_back(GetPathInfo(pp));
}
closedir(dir);
#else
WIN32_FIND_DATA fd;
std::string pattern = path.name + "/*";
HANDLE handle = FindFirstFile(pattern.c_str(), &fd);
if (handle == INVALID_HANDLE_VALUE) {
int errsv = GetLastError();
LOG(FATAL) << "LocalFileSystem.ListDirectory " << path.str()
<< " error: " << strerror(errsv);
}
do {
if (strcmp(fd.cFileName, ".") && strcmp(fd.cFileName, "..")) {
URI pp = path;
char clast = pp.name[pp.name.length() - 1];
if (pp.name == ".") {
pp.name = fd.cFileName;
} else if (clast != '/' && clast != '\\') {
pp.name += '/';
pp.name += fd.cFileName;
}
out_list->push_back(GetPathInfo(pp));
}
} while (FindNextFile(handle, &fd));
FindClose(handle);
#endif
}
SeekStream *LocalFileSystem::Open(const URI &path,
const char* const mode,
bool allow_null) {
bool use_stdio = false;
FILE *fp = NULL;
const char *fname = path.name.c_str();
using namespace std;
#ifndef DMLC_DISABLE_STDIN
if (!strcmp(fname, "stdin")) {
use_stdio = true; fp = stdin;
}
if (!strcmp(fname, "stdout")) {
use_stdio = true; fp = stdout;
}
#endif
if (!strncmp(fname, "file://", 7)) fname += 7;
if (!use_stdio) {
std::string flag = mode;
if (flag == "w") flag = "wb";
if (flag == "r") flag = "rb";
fp = fopen64(fname, flag.c_str());
}
if (fp != NULL) {
return new FileStream(fp, use_stdio);
} else {
CHECK(allow_null) << " LocalFileSystem: fail to open \"" << path.str() << '\"';
return NULL;
}
}
SeekStream *LocalFileSystem::OpenForRead(const URI &path, bool allow_null) {
return Open(path, "r", allow_null);
}
} // namespace io
} // namespace dmlc
//===== EXPANDED: ../dmlc-core/src/io/local_filesys.cc =====
//===== EXPANDING: ../dmlc-core/src/data.cc =====
// Copyright by Contributors
//===== EXPANDING: ../dmlc-core/include/dmlc/data.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file data.h
* \brief defines common input data structure,
* and interface for handling the input data
*/
#ifndef DMLC_DATA_H_
#define DMLC_DATA_H_
//===== EXPANDING: ../dmlc-core/include/dmlc/registry.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file registry.h
* \brief Registry utility that helps to build registry singletons.
*/
#ifndef DMLC_REGISTRY_H_
#define DMLC_REGISTRY_H_
//===== EXPANDING: ../dmlc-core/include/dmlc/parameter.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file parameter.h
* \brief Provide lightweight util to do parameter setup and checking.
*/
#ifndef DMLC_PARAMETER_H_
#define DMLC_PARAMETER_H_
//===== EXPANDING: ../dmlc-core/include/dmlc/json.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file json.h
 * \brief Lightweight JSON Reader/Writer that reads and saves into C++ data structs.
* This includes STL composites and structures.
*/
#ifndef DMLC_JSON_H_
#define DMLC_JSON_H_
// This code requires C++11 to compile
#if DMLC_USE_CXX11
#if DMLC_STRICT_CXX11 && DMLC_ENABLE_RTTI
//===== EXPANDING: ../dmlc-core/include/dmlc/any.h =====
/*!
* Copyright (c) 2016 by Contributors
* \file any.h
* \brief Container to hold any data type.
*/
#ifndef DMLC_ANY_H_
#define DMLC_ANY_H_
// This code need c++11 to compile
namespace dmlc {
// forward declare any;
class any;
/*!
* Get a reference to content stored in the any as type T.
* This will cause an error if
* T does not match the type stored.
* This function is not part of std::any standard.
*
 * \param src The source any container.
* \return The reference of content
* \tparam T The type of the value to be fetched.
*/
template<typename T>
inline T& get(any& src); // NOLINT(*)
/*!
* Get the const reference content stored in the any as type T.
* This will cause an error if
* T does not match the type stored.
* This function is not part of std::any standard.
*
 * \param src The source any container.
* \return The reference of content
* \tparam T The type of the value to be fetched.
*/
template<typename T>
inline const T& get(const any& src);
/*!
 * \brief An any class that is compatible with std::any in c++17.
*
* \code
* dmlc::any a = std::string("mydear"), b = 1;
* // get reference out and add it
* dmlc::get<int>(b) += 1;
* // a is now string
* LOG(INFO) << dmlc::get<std::string>(a);
* // a is now 2, the string stored will be properly destructed
* a = std::move(b);
* LOG(INFO) << dmlc::get<int>(a);
* \endcode
* \sa get
*/
class any {
public:
/*! \brief default constructor */
inline any() = default;
/*!
* \brief move constructor from another any
* \param other The other any to be moved
*/
inline any(any&& other); // NOLINT(*)
/*!
* \brief copy constructor
* \param other The other any to be copied
*/
inline any(const any& other); // NOLINT(*)
/*!
* \brief constructor from any types
* \param other The other types to be constructed into any.
* \tparam T The value type of other.
*/
template<typename T>
inline any(T&& other); // NOLINT(*)
/*! \brief destructor */
inline ~any();
/*!
* \brief assign operator from other
* \param other The other any to be copy or moved.
* \return self
*/
inline any& operator=(any&& other);
/*!
* \brief assign operator from other
* \param other The other any to be copy or moved.
* \return self
*/
inline any& operator=(const any& other);
/*!
* \brief assign operator from any type.
* \param other The other any to be copy or moved.
* \tparam T The value type of other.
* \return self
*/
template<typename T>
inline any& operator=(T&& other);
/*!
* \return whether the container is empty.
*/
inline bool empty() const;
/*!
   * \brief clear the content of the container
*/
inline void clear();
/*!
* swap current content with other
* \param other The other data to be swapped.
*/
inline void swap(any& other); // NOLINT(*)
/*!
* \return The type_info about the stored type.
*/
inline const std::type_info& type() const;
private:
//! \cond Doxygen_Suppress
// declare of helper class
template<typename T>
class TypeOnHeap;
template<typename T>
class TypeOnStack;
template<typename T>
class TypeInfo;
// size of stack space, it takes 32 bytes for one any type.
static const size_t kStack = sizeof(void*) * 3;
static const size_t kAlign = sizeof(void*);
  // the container uses dynamic storage only when the object is too large for the stack space
union Data {
// stack space
std::aligned_storage<kStack, kAlign>::type stack;
// pointer to heap space
void* pheap;
};
// type specific information
struct Type {
// destructor function
void (*destroy)(Data* data);
// copy constructor
void (*create_from_data)(Data* dst, const Data& src);
// the type info function
const std::type_info* ptype_info;
};
  // constant to check if data can be stored on the stack.
template<typename T>
struct data_on_stack {
static const bool value = alignof(T) <= kAlign && sizeof(T) <= kStack;
};
// declare friend with
template<typename T>
friend T& get(any& src); // NOLINT(*)
template<typename T>
friend const T& get(const any& src);
// internal construct function
inline void construct(any&& other);
// internal construct function
inline void construct(const any& other);
// internal function to check if type is correct.
template<typename T>
inline void check_type() const;
// internal type specific information
const Type* type_{nullptr};
// internal data
Data data_;
};
template<typename T>
inline any::any(T&& other) {
typedef typename std::decay<T>::type DT;
if (std::is_same<DT, any>::value) {
this->construct(std::forward<T>(other));
} else {
static_assert(std::is_copy_constructible<DT>::value,
"Any can only hold value that is copy constructable");
type_ = TypeInfo<DT>::get_type();
if (data_on_stack<DT>::value) {
new (&(data_.stack)) DT(std::forward<T>(other));
} else {
data_.pheap = new DT(std::forward<T>(other));
}
}
}
inline any::any(any&& other) {
this->construct(std::move(other));
}
inline any::any(const any& other) {
this->construct(other);
}
inline void any::construct(any&& other) {
type_ = other.type_;
data_ = other.data_;
other.type_ = nullptr;
}
inline void any::construct(const any& other) {
type_ = other.type_;
if (type_ != nullptr) {
type_->create_from_data(&data_, other.data_);
}
}
inline any::~any() {
this->clear();
}
inline any& any::operator=(any&& other) {
any(std::move(other)).swap(*this);
return *this;
}
inline any& any::operator=(const any& other) {
any(other).swap(*this);
return *this;
}
template<typename T>
inline any& any::operator=(T&& other) {
any(std::forward<T>(other)).swap(*this);
return *this;
}
inline void any::swap(any& other) { // NOLINT(*)
std::swap(type_, other.type_);
std::swap(data_, other.data_);
}
inline void any::clear() {
if (type_ != nullptr) {
if (type_->destroy != nullptr) {
type_->destroy(&data_);
}
type_ = nullptr;
}
}
inline bool any::empty() const {
return type_ == nullptr;
}
inline const std::type_info& any::type() const {
if (type_ != nullptr) {
return *(type_->ptype_info);
} else {
return typeid(void);
}
}
template<typename T>
inline void any::check_type() const {
CHECK(type_ != nullptr)
<< "The any container is empty"
<< " requested=" << typeid(T).name();
CHECK(type_->ptype_info == &typeid(T))
<< "The stored type mismatch"
<< " stored=" << type_->ptype_info->name()
<< " requested=" << typeid(T).name();
}
template<typename T>
inline const T& get(const any& src) {
src.check_type<T>();
return *any::TypeInfo<T>::get_ptr(&(src.data_));
}
template<typename T>
inline T& get(any& src) { // NOLINT(*)
src.check_type<T>();
return *any::TypeInfo<T>::get_ptr(&(src.data_));
}
template<typename T>
class any::TypeOnHeap {
public:
inline static T* get_ptr(any::Data* data) {
return static_cast<T*>(data->pheap);
}
inline static const T* get_ptr(const any::Data* data) {
return static_cast<const T*>(data->pheap);
}
inline static void create_from_data(any::Data* dst, const any::Data& data) {
dst->pheap = new T(*get_ptr(&data));
}
inline static void destroy(Data* data) {
delete static_cast<T*>(data->pheap);
}
};
template<typename T>
class any::TypeOnStack {
public:
inline static T* get_ptr(any::Data* data) {
return reinterpret_cast<T*>(&(data->stack));
}
inline static const T* get_ptr(const any::Data* data) {
return reinterpret_cast<const T*>(&(data->stack));
}
inline static void create_from_data(any::Data* dst, const any::Data& data) {
new (&(dst->stack)) T(*get_ptr(&data));
}
inline static void destroy(Data* data) {
T* dptr = reinterpret_cast<T*>(&(data->stack));
dptr->~T();
}
};
template<typename T>
class any::TypeInfo
: public std::conditional<any::data_on_stack<T>::value,
any::TypeOnStack<T>,
any::TypeOnHeap<T> >::type {
public:
inline static const Type* get_type() {
static TypeInfo<T> tp;
return &(tp.type_);
}
private:
// local type
Type type_;
// constructor
TypeInfo() {
if (std::is_pod<T>::value) {
type_.destroy = nullptr;
} else {
type_.destroy = TypeInfo<T>::destroy;
}
type_.create_from_data = TypeInfo<T>::create_from_data;
type_.ptype_info = &typeid(T);
}
};
//! \endcond
} // namespace dmlc
#endif // DMLC_ANY_H_
//===== EXPANDED: ../dmlc-core/include/dmlc/any.h =====
#endif // DMLC_STRICT_CXX11
#endif // DMLC_USE_CXX11
namespace dmlc {
/*!
* \brief Lightweight JSON Reader to read any STL compositions and structs.
 *  The user needs to know the schema of the JSON to be read.
*
*/
class JSONReader {
public:
/*!
* \brief Constructor.
* \param is the input stream.
*/
explicit JSONReader(std::istream *is)
: is_(is),
line_count_r_(0),
line_count_n_(0) {}
/*!
* \brief Parse next JSON string.
* \param out_str the output string.
* \throw dmlc::Error when next token is not string
*/
inline void ReadString(std::string *out_str);
/*!
* \brief Read Number.
* \param out_value output value;
* \throw dmlc::Error when next token is not number of ValueType.
* \tparam ValueType type of the number
*/
template<typename ValueType>
inline void ReadNumber(ValueType *out_value);
/*!
* \brief Begin parsing an object.
* \code
* std::string key;
* // value can be any type that is json serializable.
* std::string value;
* reader->BeginObject();
* while (reader->NextObjectItem(&key)) {
 *    // do something with key and value
* reader->Read(&value);
* }
* \endcode
*/
inline void BeginObject();
/*!
* \brief Begin parsing an array.
* \code
* // value can be any type that is json serializable.
* std::string value;
* reader->BeginArray();
 *  while (reader->NextArrayItem()) {
 *    reader->Read(&value);
 *    // do something with value
* }
* \endcode
*/
inline void BeginArray();
/*!
* \brief Try to move to next object item.
* If this call is successful, user can proceed to call
* reader->Read to read in the value.
* \param out_key the key to the next object.
* \return true if the read is successful, false if we are at end of the object.
*/
inline bool NextObjectItem(std::string *out_key);
/*!
* \brief Try to read the next element in the array.
* If this call is successful, user can proceed to call
* reader->Read to read in the value.
* \return true if the read is successful, false if we are at end of the array.
*/
inline bool NextArrayItem();
/*!
* \brief Read next ValueType.
* \param out_value any STL or json readable type to be read
* \throw dmlc::Error when the read of ValueType is not successful.
* \tparam ValueType the data type to be read.
*/
template<typename ValueType>
inline void Read(ValueType *out_value);
  /*! \return a string describing the current line, used in error messages */
inline std::string line_info() const {
char temp[64];
std::ostringstream os;
os << " Line " << std::max(line_count_r_, line_count_n_);
is_->getline(temp, 64);
os << ", around ^`" << temp << "`";
return os.str();
}
private:
/*! \brief internal reader stream */
std::istream *is_;
/*! \brief "\\r" counter */
size_t line_count_r_;
/*! \brief "\\n" counter */
size_t line_count_n_;
/*!
   * \brief record how many elements have been processed in
* current array/object scope.
*/
std::vector<size_t> scope_counter_;
/*!
* \brief Read next nonspace character.
* \return the next nonspace character.
*/
inline int NextNonSpace();
/*!
   * \brief Peek at the next nonspace character without consuming it.
* \return the next nonspace character.
*/
inline int PeekNextNonSpace();
};
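/*
 * A minimal usage sketch for JSONReader, reading a string->int map out of an
 * std::istringstream (any std::istream works):
 *
 *   std::istringstream is("{\"a\": 1, \"b\": 2}");
 *   dmlc::JSONReader reader(&is);
 *   std::map<std::string, int> kv;
 *   reader.Read(&kv);  // kv == {{"a", 1}, {"b", 2}}
 */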
/*!
* \brief Lightweight json to write any STL compositions.
*/
class JSONWriter {
public:
/*!
* \brief Constructor.
* \param os the output stream.
*/
explicit JSONWriter(std::ostream *os)
: os_(os) {}
/*!
   * \brief Write a string that does not contain escape characters.
* \param s the string to be written.
*/
inline void WriteNoEscape(const std::string &s);
/*!
* \brief Write a string that can contain escape characters.
* \param s the string to be written.
*/
inline void WriteString(const std::string &s);
/*!
   * \brief Write a numeric value.
* \param v the value to be written.
* \tparam ValueType The value type to be written.
*/
template<typename ValueType>
inline void WriteNumber(const ValueType &v);
/*!
   * \brief Begin writing an array.
   * \param multi_line whether to write a multi-line array.
* \code
* writer->BeginArray();
* for (auto& v : vdata) {
* writer->WriteArrayItem(v);
* }
* writer->EndArray();
* \endcode
*/
inline void BeginArray(bool multi_line = true);
/*! \brief Finish writing an array. */
inline void EndArray();
/*!
   * \brief Begin writing an object.
   * \param multi_line whether to write a multi-line object.
* \code
* writer->BeginObject();
* for (auto& kv : vmap) {
* writer->WriteObjectKeyValue(kv.first, kv.second);
* }
* writer->EndObject();
* \endcode
*/
inline void BeginObject(bool multi_line = true);
/*! \brief Finish writing object. */
inline void EndObject();
/*!
* \brief Write key value pair in the object.
* \param key the key of the object.
   * \param value the value to be written.
* \tparam ValueType The value type to be written.
*/
template<typename ValueType>
inline void WriteObjectKeyValue(const std::string &key,
const ValueType &value);
/*!
   * \brief Write the separator of an array, before writing the next element.
   *  The user can then call writer->Write to write the next item.
*/
inline void WriteArraySeperator();
/*!
* \brief Write value into array.
   * \param value The value to be written.
* \tparam ValueType The value type to be written.
*/
template<typename ValueType>
inline void WriteArrayItem(const ValueType &value);
/*!
* \brief Write value to json.
* \param value any STL or json readable that can be written.
   * \tparam ValueType the data type to be written.
*/
template<typename ValueType>
inline void Write(const ValueType &value);
private:
/*! \brief Output stream */
std::ostream *os_;
/*!
   * \brief record how many elements have been processed in
* current array/object scope.
*/
std::vector<size_t> scope_counter_;
/*! \brief Record whether current is a multiline scope */
std::vector<bool> scope_multi_line_;
/*!
   * \brief Write separating spaces and newlines
*/
inline void WriteSeperator();
};
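/*
 * The corresponding write-side sketch, producing JSON into a string stream:
 *
 *   std::ostringstream os;
 *   dmlc::JSONWriter writer(&os);
 *   std::map<std::string, int> kv;
 *   kv["a"] = 1;
 *   writer.Write(kv);
 *   // os.str() now holds a JSON object with key "a"
 */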
/*!
* \brief Helper class to read JSON into a class or struct object.
* \code
* struct Param {
* std::string name;
* int value;
* // define load function from JSON
* inline void Load(dmlc::JSONReader *reader) {
 *     dmlc::JSONObjectReadHelper helper;
* helper.DeclareField("name", &name);
* helper.DeclareField("value", &value);
* helper.ReadAllFields(reader);
* }
* };
* \endcode
*/
class JSONObjectReadHelper {
public:
/*!
* \brief Declare field of type T
   * \param key the key of the field.
* \param addr address of the data type.
* \tparam T the data type to be read, must be STL composition of JSON serializable.
*/
template<typename T>
inline void DeclareField(const std::string &key, T *addr) {
DeclareFieldInternal(key, addr, false);
}
/*!
* \brief Declare optional field of type T
   * \param key the key of the field.
* \param addr address of the data type.
* \tparam T the data type to be read, must be STL composition of JSON serializable.
*/
template<typename T>
inline void DeclareOptionalField(const std::string &key, T *addr) {
DeclareFieldInternal(key, addr, true);
}
/*!
* \brief Read in all the declared fields.
* \param reader the JSONReader to read the json.
*/
inline void ReadAllFields(JSONReader *reader);
private:
/*!
* \brief Internal function to declare field.
   * \param key the key of the field.
* \param addr address of the data type.
   * \param optional if set to true, no error will be reported if the key is not present.
* \tparam T the data type to be read, must be STL composition of JSON serializable.
*/
template<typename T>
inline void DeclareFieldInternal(const std::string &key, T *addr, bool optional);
/*!
* \brief The internal reader function.
* \param reader The reader to read.
* \param addr The memory address to read.
*/
template<typename T>
inline static void ReaderFunction(JSONReader *reader, void *addr);
/*! \brief callback type to reader function */
typedef void (*ReadFunction)(JSONReader *reader, void *addr);
/*! \brief internal data entry */
struct Entry {
/*! \brief the reader function */
ReadFunction func;
/*! \brief the address to read */
void *addr;
/*! \brief whether it is optional */
bool optional;
};
/*! \brief the internal map of reader callbacks */
std::map<std::string, Entry> map_;
};
#define DMLC_JSON_ENABLE_ANY_VAR_DEF(KeyName) \
static DMLC_ATTRIBUTE_UNUSED ::dmlc::json::AnyJSONManager& \
__make_AnyJSONType ## _ ## KeyName ## __
/*!
* \def DMLC_JSON_ENABLE_ANY
* \brief Macro to enable save/load JSON of dmlc:: whose actual type is Type.
* Any type will be saved as json array [KeyName, content]
*
* \param Type The type to be registered.
* \param KeyName The Type key assigned to the type, must be same during load.
*/
#define DMLC_JSON_ENABLE_ANY(Type, KeyName) \
DMLC_STR_CONCAT(DMLC_JSON_ENABLE_ANY_VAR_DEF(KeyName), __COUNTER__) = \
::dmlc::json::AnyJSONManager::Global()->EnableType<Type>(#KeyName) \
//! \cond Doxygen_Suppress
namespace json {
/*!
* \brief generic serialization handler
* \tparam T the type to be serialized
*/
template<typename T>
struct Handler;
template<typename ValueType>
struct NumericHandler {
inline static void Write(JSONWriter *writer, const ValueType &value) {
writer->WriteNumber<ValueType>(value);
}
inline static void Read(JSONReader *reader, ValueType *value) {
reader->ReadNumber<ValueType>(value);
}
};
template<typename ContainerType>
struct ArrayHandler {
inline static void Write(JSONWriter *writer, const ContainerType &array) {
typedef typename ContainerType::value_type ElemType;
writer->BeginArray(array.size() > 10 || !dmlc::is_pod<ElemType>::value);
for (typename ContainerType::const_iterator it = array.begin();
it != array.end(); ++it) {
writer->WriteArrayItem(*it);
}
writer->EndArray();
}
inline static void Read(JSONReader *reader, ContainerType *array) {
typedef typename ContainerType::value_type ElemType;
array->clear();
reader->BeginArray();
while (reader->NextArrayItem()) {
ElemType value;
Handler<ElemType>::Read(reader, &value);
array->insert(array->end(), value);
}
}
};
template<typename ContainerType>
struct MapHandler{
inline static void Write(JSONWriter *writer, const ContainerType &map) {
writer->BeginObject(map.size() > 1);
for (typename ContainerType::const_iterator it = map.begin(); it != map.end(); ++it) {
writer->WriteObjectKeyValue(it->first, it->second);
}
writer->EndObject();
}
inline static void Read(JSONReader *reader, ContainerType *map) {
typedef typename ContainerType::mapped_type ElemType;
map->clear();
reader->BeginObject();
std::string key;
while (reader->NextObjectItem(&key)) {
ElemType value;
reader->Read(&value);
(*map)[key] = value;
}
}
};
template<typename T>
struct CommonJSONSerializer {
inline static void Write(JSONWriter *writer, const T &value) {
value.Save(writer);
}
inline static void Read(JSONReader *reader, T *value) {
value->Load(reader);
}
};
template<>
struct Handler<std::string> {
inline static void Write(JSONWriter *writer, const std::string &value) {
writer->WriteString(value);
}
inline static void Read(JSONReader *reader, std::string *str) {
reader->ReadString(str);
}
};
template<typename T>
struct Handler<std::vector<T> > : public ArrayHandler<std::vector<T> > {
};
template<typename K, typename V>
struct Handler<std::pair<K, V> > {
inline static void Write(JSONWriter *writer, const std::pair<K, V> &kv) {
writer->BeginArray();
writer->WriteArrayItem(kv.first);
writer->WriteArrayItem(kv.second);
writer->EndArray();
}
inline static void Read(JSONReader *reader, std::pair<K, V> *kv) {
reader->BeginArray();
CHECK(reader->NextArrayItem())
<< "Expect array of length 2";
Handler<K>::Read(reader, &(kv->first));
CHECK(reader->NextArrayItem())
<< "Expect array of length 2";
Handler<V>::Read(reader, &(kv->second));
CHECK(!reader->NextArrayItem())
<< "Expect array of length 2";
}
};
template<typename T>
struct Handler<std::list<T> > : public ArrayHandler<std::list<T> > {
};
template<typename V>
struct Handler<std::map<std::string, V> > : public MapHandler<std::map<std::string, V> > {
};
#if DMLC_USE_CXX11
template<typename V>
struct Handler<std::unordered_map<std::string, V> >
: public MapHandler<std::unordered_map<std::string, V> > {
};
#endif // DMLC_USE_CXX11
template<typename T>
struct Handler {
inline static void Write(JSONWriter *writer, const T &data) {
typedef typename dmlc::IfThenElseType<dmlc::is_arithmetic<T>::value,
NumericHandler<T>,
CommonJSONSerializer<T> >::Type THandler;
THandler::Write(writer, data);
}
inline static void Read(JSONReader *reader, T *data) {
typedef typename dmlc::IfThenElseType<dmlc::is_arithmetic<T>::value,
NumericHandler<T>,
CommonJSONSerializer<T> >::Type THandler;
THandler::Read(reader, data);
}
};
#if DMLC_STRICT_CXX11 && DMLC_ENABLE_RTTI
// Manager to store json serialization strategy.
class AnyJSONManager {
public:
template<typename T>
inline AnyJSONManager& EnableType(const std::string& type_name) { // NOLINT(*)
std::type_index tp = std::type_index(typeid(T));
if (type_name_.count(tp) != 0) {
CHECK(type_name_.at(tp) == type_name)
<< "Type has already been registered as another typename " << type_name_.at(tp);
return *this;
}
CHECK(type_map_.count(type_name) == 0)
<< "Type name " << type_name << " already registered in registry";
Entry e;
e.read = ReadAny<T>;
e.write = WriteAny<T>;
type_name_[tp] = type_name;
type_map_[type_name] = e;
return *this;
}
// return global singleton
inline static AnyJSONManager* Global() {
static AnyJSONManager inst;
return &inst;
}
private:
AnyJSONManager() {}
template<typename T>
inline static void WriteAny(JSONWriter *writer, const any &data) {
writer->Write(dmlc::get<T>(data));
}
template<typename T>
inline static void ReadAny(JSONReader *reader, any* data) {
T temp;
reader->Read(&temp);
*data = std::move(temp);
}
// data entry to store vtable for any type
struct Entry {
void (*read)(JSONReader* reader, any *data);
void (*write)(JSONWriter* reader, const any& data);
};
template<typename T>
friend struct Handler;
std::unordered_map<std::type_index, std::string> type_name_;
std::unordered_map<std::string, Entry> type_map_;
};
template<>
struct Handler<any> {
inline static void Write(JSONWriter *writer, const any &data) {
std::unordered_map<std::type_index, std::string>&
nmap = AnyJSONManager::Global()->type_name_;
std::type_index id = std::type_index(data.type());
auto it = nmap.find(id);
CHECK(it != nmap.end() && it->first == id)
<< "Type " << id.name() << " has not been registered via DMLC_JSON_ENABLE_ANY";
std::string type_name = it->second;
AnyJSONManager::Entry e = AnyJSONManager::Global()->type_map_.at(type_name);
writer->BeginArray(false);
writer->WriteArrayItem(type_name);
writer->WriteArraySeperator();
e.write(writer, data);
writer->EndArray();
}
inline static void Read(JSONReader *reader, any *data) {
std::string type_name;
reader->BeginArray();
CHECK(reader->NextArrayItem()) << "invalid any json format";
Handler<std::string>::Read(reader, &type_name);
std::unordered_map<std::string, AnyJSONManager::Entry>&
tmap = AnyJSONManager::Global()->type_map_;
auto it = tmap.find(type_name);
CHECK(it != tmap.end() && it->first == type_name)
<< "Typename " << type_name << " has not been registered via DMLC_JSON_ENABLE_ANY";
AnyJSONManager::Entry e = it->second;
CHECK(reader->NextArrayItem()) << "invalid any json format";
e.read(reader, data);
CHECK(!reader->NextArrayItem()) << "invalid any json format";
}
};
#endif // DMLC_STRICT_CXX11
} // namespace json
// implementations of JSONReader/Writer
inline int JSONReader::NextNonSpace() {
int ch;
do {
ch = is_->get();
if (ch == '\n') ++line_count_n_;
if (ch == '\r') ++line_count_r_;
} while (isspace(ch));
return ch;
}
inline int JSONReader::PeekNextNonSpace() {
int ch;
while (true) {
ch = is_->peek();
if (ch == '\n') ++line_count_n_;
if (ch == '\r') ++line_count_r_;
if (!isspace(ch)) break;
is_->get();
}
return ch;
}
inline void JSONReader::ReadString(std::string *out_str) {
int ch = NextNonSpace();
CHECK_EQ(ch, '\"')
<< "Error at" << line_info()
<< ", Expect \'\"\' but get \'" << static_cast<char>(ch) << '\'';
std::ostringstream os;
while (true) {
ch = is_->get();
if (ch == '\\') {
char sch = static_cast<char>(is_->get());
switch (sch) {
case 'r': os << "\r"; break;
case 'n': os << "\n"; break;
case '\\': os << "\\"; break;
case '\t': os << "\t"; break;
case '\"': os << "\""; break;
default: LOG(FATAL) << "unknown string escape \\" << sch;
}
} else {
if (ch == '\"') break;
os << static_cast<char>(ch);
}
if (ch == EOF || ch == '\r' || ch == '\n') {
LOG(FATAL)
<< "Error at" << line_info()
<< ", Expect \'\"\' but reach end of line ";
}
}
*out_str = os.str();
}
template<typename ValueType>
inline void JSONReader::ReadNumber(ValueType *out_value) {
*is_ >> *out_value;
CHECK(!is_->fail())
<< "Error at" << line_info()
<< ", Expect number";
}
inline void JSONReader::BeginObject() {
int ch = NextNonSpace();
CHECK_EQ(ch, '{')
<< "Error at" << line_info()
<< ", Expect \'{\' but get \'" << static_cast<char>(ch) << '\'';
scope_counter_.push_back(0);
}
inline void JSONReader::BeginArray() {
int ch = NextNonSpace();
CHECK_EQ(ch, '[')
<< "Error at" << line_info()
<< ", Expect \'{\' but get \'" << static_cast<char>(ch) << '\'';
scope_counter_.push_back(0);
}
inline bool JSONReader::NextObjectItem(std::string *out_key) {
bool next = true;
if (scope_counter_.back() != 0) {
int ch = NextNonSpace();
if (ch == EOF) {
next = false;
} else if (ch == '}') {
next = false;
} else {
CHECK_EQ(ch, ',')
<< "Error at" << line_info()
<< ", JSON object expect \'}\' or \',\' \'" << static_cast<char>(ch) << '\'';
}
} else {
int ch = PeekNextNonSpace();
if (ch == '}') {
is_->get();
next = false;
}
}
if (!next) {
scope_counter_.pop_back();
return false;
} else {
scope_counter_.back() += 1;
ReadString(out_key);
int ch = NextNonSpace();
CHECK_EQ(ch, ':')
<< "Error at" << line_info()
<< ", Expect \':\' but get \'" << static_cast<char>(ch) << '\'';
return true;
}
}
inline bool JSONReader::NextArrayItem() {
bool next = true;
if (scope_counter_.back() != 0) {
int ch = NextNonSpace();
if (ch == EOF) {
next = false;
} else if (ch == ']') {
next = false;
} else {
CHECK_EQ(ch, ',')
<< "Error at" << line_info()
<< ", JSON array expect \']\' or \',\'. Get \'" << static_cast<char>(ch) << "\' instead";
}
} else {
int ch = PeekNextNonSpace();
if (ch == ']') {
is_->get();
next = false;
}
}
if (!next) {
scope_counter_.pop_back();
return false;
} else {
scope_counter_.back() += 1;
return true;
}
}
template<typename ValueType>
inline void JSONReader::Read(ValueType *out_value) {
json::Handler<ValueType>::Read(this, out_value);
}
inline void JSONWriter::WriteNoEscape(const std::string &s) {
*os_ << '\"' << s << '\"';
}
inline void JSONWriter::WriteString(const std::string &s) {
std::ostream &os = *os_;
os << '\"';
for (size_t i = 0; i < s.length(); ++i) {
char ch = s[i];
switch (ch) {
case '\r': os << "\\r"; break;
case '\n': os << "\\n"; break;
case '\\': os << "\\\\"; break;
case '\t': os << "\\t"; break;
case '\"': os << "\\\""; break;
default: os << ch;
}
}
os << '\"';
}
template<typename ValueType>
inline void JSONWriter::WriteNumber(const ValueType &v) {
*os_ << v;
}
inline void JSONWriter::BeginArray(bool multi_line) {
*os_ << '[';
scope_multi_line_.push_back(multi_line);
scope_counter_.push_back(0);
}
inline void JSONWriter::EndArray() {
CHECK_NE(scope_multi_line_.size(), 0U);
CHECK_NE(scope_counter_.size(), 0U);
bool newline = scope_multi_line_.back();
size_t nelem = scope_counter_.back();
scope_multi_line_.pop_back();
scope_counter_.pop_back();
if (newline && nelem != 0) WriteSeperator();
*os_ << ']';
}
inline void JSONWriter::BeginObject(bool multi_line) {
*os_ << "{";
scope_multi_line_.push_back(multi_line);
scope_counter_.push_back(0);
}
inline void JSONWriter::EndObject() {
CHECK_NE(scope_multi_line_.size(), 0U);
CHECK_NE(scope_counter_.size(), 0U);
bool newline = scope_multi_line_.back();
size_t nelem = scope_counter_.back();
scope_multi_line_.pop_back();
scope_counter_.pop_back();
if (newline && nelem != 0) WriteSeperator();
*os_ << '}';
}
template<typename ValueType>
inline void JSONWriter::WriteObjectKeyValue(const std::string &key,
const ValueType &value) {
std::ostream &os = *os_;
if (scope_counter_.back() == 0) {
WriteSeperator();
os << '\"' << key << "\": ";
} else {
os << ", ";
WriteSeperator();
os << '\"' << key << "\": ";
}
scope_counter_.back() += 1;
json::Handler<ValueType>::Write(this, value);
}
inline void JSONWriter::WriteArraySeperator() {
std::ostream &os = *os_;
if (scope_counter_.back() != 0) {
os << ", ";
}
scope_counter_.back() += 1;
WriteSeperator();
}
template<typename ValueType>
inline void JSONWriter::WriteArrayItem(const ValueType &value) {
this->WriteArraySeperator();
json::Handler<ValueType>::Write(this, value);
}
template<typename ValueType>
inline void JSONWriter::Write(const ValueType &value) {
size_t nscope = scope_multi_line_.size();
json::Handler<ValueType>::Write(this, value);
CHECK_EQ(nscope, scope_multi_line_.size())
<< "Uneven scope, did you call EndArray/EndObject after each BeginObject/Array?";
}
inline void JSONWriter::WriteSeperator() {
if (scope_multi_line_.size() == 0 || scope_multi_line_.back()) {
*os_ << '\n' << std::string(scope_multi_line_.size() * 2, ' ');
}
}
inline void JSONObjectReadHelper::ReadAllFields(JSONReader *reader) {
reader->BeginObject();
std::map<std::string, int> visited;
std::string key;
while (reader->NextObjectItem(&key)) {
if (map_.count(key) != 0) {
Entry e = map_[key];
(*e.func)(reader, e.addr);
visited[key] = 0;
} else {
std::ostringstream os;
os << "JSONReader: Unknown field " << key << ", candidates are: \n";
for (std::map<std::string, Entry>::iterator
it = map_.begin(); it != map_.end(); ++it) {
os << '\"' <<it->first << "\"\n";
}
LOG(FATAL) << os.str();
}
}
if (visited.size() != map_.size()) {
for (std::map<std::string, Entry>::iterator
it = map_.begin(); it != map_.end(); ++it) {
if (it->second.optional) continue;
CHECK_NE(visited.count(it->first), 0U)
<< "JSONReader: Missing field \"" << it->first << "\"\n At "
<< reader->line_info();
}
}
}
template<typename T>
inline void JSONObjectReadHelper::ReaderFunction(JSONReader *reader, void *addr) {
json::Handler<T>::Read(reader, static_cast<T*>(addr));
}
template<typename T>
inline void JSONObjectReadHelper::
DeclareFieldInternal(const std::string &key, T *addr, bool optional) {
CHECK_EQ(map_.count(key), 0U)
<< "Adding duplicate field " << key;
Entry e;
e.func = ReaderFunction<T>;
e.addr = static_cast<void*>(addr);
e.optional = optional;
map_[key] = e;
}
//! \endcond
} // namespace dmlc
#endif // DMLC_JSON_H_
//===== EXPANDED: ../dmlc-core/include/dmlc/json.h =====
//===== EXPANDING: ../dmlc-core/include/dmlc/optional.h =====
/*!
* Copyright (c) 2016 by Contributors
* \file optional.h
* \brief Container to hold optional data.
*/
#ifndef DMLC_OPTIONAL_H_
#define DMLC_OPTIONAL_H_
namespace dmlc {
/*! \brief dummy type for assign null to optional */
struct nullopt_t {
#if defined(_MSC_VER) && _MSC_VER < 1900
/*! \brief dummy constructor */
explicit nullopt_t(int a) {}
#else
/*! \brief dummy constructor */
constexpr nullopt_t(int a) {}
#endif
};
/*! Assign null to optional: optional<T> x = nullopt; */
constexpr const nullopt_t nullopt = nullopt_t(0);
/*!
* \brief c++17 compatible optional class.
*
* At any time an optional<T> instance either
* hold no value (string representation "None")
* or hold a value of type T.
*/
template<typename T>
class optional {
public:
/*! \brief construct an optional object that contains no value */
optional() : is_none(true) {}
/*! \brief construct an optional object with value */
explicit optional(const T& value) {
is_none = false;
new (&val) T(value);
}
/*! \brief construct an optional object with another optional object */
optional(const optional<T>& other) {
is_none = other.is_none;
if (!is_none) {
new (&val) T(other.value());
}
}
  /*! \brief destructor */
~optional() {
if (!is_none) {
reinterpret_cast<T*>(&val)->~T();
}
}
/*! \brief swap two optional */
void swap(optional<T>& other) {
std::swap(val, other.val);
std::swap(is_none, other.is_none);
}
/*! \brief set this object to hold value
* \param value the value to hold
* \return return self to support chain assignment
*/
optional<T>& operator=(const T& value) {
(optional<T>(value)).swap(*this);
return *this;
}
/*! \brief set this object to hold the same value with other
* \param other the other object
* \return return self to support chain assignment
*/
optional<T>& operator=(const optional<T> &other) {
(optional<T>(other)).swap(*this);
return *this;
}
/*! \brief clear the value this object is holding.
* optional<T> x = nullopt;
*/
optional<T>& operator=(nullopt_t) {
(optional<T>()).swap(*this);
return *this;
}
/*! \brief non-const dereference operator */
T& operator*() { // NOLINT(*)
return *reinterpret_cast<T*>(&val);
}
/*! \brief const dereference operator */
const T& operator*() const {
return *reinterpret_cast<const T*>(&val);
}
  /*! \brief return the held value.
* throws std::logic_error if holding no value
*/
const T& value() const {
if (is_none) {
throw std::logic_error("bad optional access");
}
return *reinterpret_cast<const T*>(&val);
}
/*! \brief whether this object is holding a value */
explicit operator bool() const { return !is_none; }
private:
// whether this is none
bool is_none;
// on stack storage of value
typename std::aligned_storage<sizeof(T), alignof(T)>::type val;
};
/*! \brief serialize an optional object to string.
*
* \code
* dmlc::optional<int> x;
* std::cout << x; // None
* x = 0;
* std::cout << x; // 0
* \endcode
*
* \param os output stream
* \param t source optional<T> object
* \return output stream
*/
template<typename T>
std::ostream &operator<<(std::ostream &os, const optional<T> &t) {
if (t) {
os << *t;
} else {
os << "None";
}
return os;
}
/*! \brief parse a string object into optional<T>
*
* \code
* dmlc::optional<int> x;
* std::string s1 = "1";
* std::istringstream is1(s1);
 *   is1 >> x;  // x == optional<int>(1)
*
* std::string s2 = "None";
* std::istringstream is2(s2);
 *   is2 >> x;  // x == optional<int>()
* \endcode
*
* \param is input stream
* \param t target optional<T> object
* \return input stream
*/
template<typename T>
std::istream &operator>>(std::istream &is, optional<T> &t) {
char buf[4];
std::streampos origin = is.tellg();
is.read(buf, 4);
if (is.fail() || buf[0] != 'N' || buf[1] != 'o' ||
buf[2] != 'n' || buf[3] != 'e') {
is.clear();
is.seekg(origin);
T x;
is >> x;
t = x;
} else {
t = nullopt;
}
return is;
}
/*! \brief description for optional int */
DMLC_DECLARE_TYPE_NAME(optional<int>, "int or None");
} // namespace dmlc
#endif // DMLC_OPTIONAL_H_
//===== EXPANDED: ../dmlc-core/include/dmlc/optional.h =====
namespace dmlc {
// this file is backward compatible with non-c++11
/*! \brief Error thrown by parameter checking */
struct ParamError : public dmlc::Error {
/*!
* \brief constructor
* \param msg error message
*/
explicit ParamError(const std::string &msg)
: dmlc::Error(msg) {}
};
/*!
* \brief Get environment variable with default.
* \param key the name of environment variable.
 * \param default_value the default value of the environment variable.
* \return The value received
*/
template<typename ValueType>
inline ValueType GetEnv(const char *key,
ValueType default_value);
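/*
 * A small usage sketch; the template is only declared here and its definition
 * lives elsewhere in dmlc-core, and the variable name below is illustrative:
 *
 *   int nthread = dmlc::GetEnv("MY_NUM_THREADS", 4);  // 4 if the variable is unset
 */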
/*! \brief internal namespace for parameter management */
namespace parameter {
// forward declare ParamManager
class ParamManager;
// forward declare FieldAccessEntry
class FieldAccessEntry;
// forward declare FieldEntry
template<typename DType>
class FieldEntry;
// forward declare ParamManagerSingleton
template<typename PType>
struct ParamManagerSingleton;
/*! \brief option in parameter initialization */
enum ParamInitOption {
/*! \brief allow unknown parameters */
kAllowUnknown,
/*! \brief need to match exact parameters */
kAllMatch,
/*! \brief allow unmatched hidden field with format __*__ */
kAllowHidden
};
} // namespace parameter
/*!
* \brief Information about a parameter field in string representations.
*/
struct ParamFieldInfo {
/*! \brief name of the field */
std::string name;
/*! \brief type of the field in string format */
std::string type;
/*!
* \brief detailed type information string
   *  This includes the default value, enum constraint and typename.
*/
std::string type_info_str;
/*! \brief detailed description of the type */
std::string description;
};
/*!
 * \brief Parameter is the base type every parameter struct should inherit from.
* The following code is a complete example to setup parameters.
* \code
* struct Param : public dmlc::Parameter<Param> {
* float learning_rate;
* int num_hidden;
* std::string name;
* // declare parameters in header file
* DMLC_DECLARE_PARAMETER(Param) {
* DMLC_DECLARE_FIELD(num_hidden).set_range(0, 1000);
* DMLC_DECLARE_FIELD(learning_rate).set_default(0.01f);
* DMLC_DECLARE_FIELD(name).set_default("hello");
* }
* };
* // register it in cc file
* DMLC_REGISTER_PARAMETER(Param);
* \endcode
*
* After that, the Param struct will get all the functions defined in Parameter.
* \tparam PType the type of parameter struct
*
* \sa DMLC_DECLARE_FIELD, DMLC_REGISTER_PARAMETER, DMLC_DECLARE_PARAMETER
*/
template<typename PType>
struct Parameter {
public:
/*!
* \brief initialize the parameter by keyword arguments.
* This function will initialize the parameter struct, check consistency
   *  and throw an error if something goes wrong.
   *
   * \param kwargs map of keyword arguments, or vector of pairs
   * \param option The option on initialization.
   * \tparam Container container type
   * \throw ParamError when something goes wrong.
*/
template<typename Container>
inline void Init(const Container &kwargs,
parameter::ParamInitOption option = parameter::kAllowHidden) {
PType::__MANAGER__()->RunInit(static_cast<PType*>(this),
kwargs.begin(), kwargs.end(),
NULL,
option);
}
/*!
* \brief initialize the parameter by keyword arguments.
   *  This is the same as Init, but allows unknown arguments.
   *
   * \param kwargs map of keyword arguments, or vector of pairs
   * \tparam Container container type
   * \throw ParamError when something goes wrong.
* \return vector of pairs of unknown arguments.
*/
template<typename Container>
inline std::vector<std::pair<std::string, std::string> >
InitAllowUnknown(const Container &kwargs) {
std::vector<std::pair<std::string, std::string> > unknown;
PType::__MANAGER__()->RunInit(static_cast<PType*>(this),
kwargs.begin(), kwargs.end(),
&unknown, parameter::kAllowUnknown);
return unknown;
}
/*!
* \brief Return a dictionary representation of the parameters
* \return A dictionary that maps key -> value
*/
inline std::map<std::string, std::string> __DICT__() const {
std::vector<std::pair<std::string, std::string> > vec
= PType::__MANAGER__()->GetDict(this->head());
return std::map<std::string, std::string>(vec.begin(), vec.end());
}
/*!
* \brief Write the parameters in JSON format.
* \param writer JSONWriter used for writing.
*/
inline void Save(dmlc::JSONWriter *writer) const {
writer->Write(this->__DICT__());
}
/*!
* \brief Load the parameters from JSON.
* \param reader JSONReader used for loading.
   * \throw ParamError when something goes wrong.
*/
inline void Load(dmlc::JSONReader *reader) {
std::map<std::string, std::string> kwargs;
reader->Read(&kwargs);
this->Init(kwargs);
}
/*!
* \brief Get the fields of the parameters.
* \return List of ParamFieldInfo of each field.
*/
inline static std::vector<ParamFieldInfo> __FIELDS__() {
return PType::__MANAGER__()->GetFieldInfo();
}
/*!
* \brief Print docstring of the parameter
* \return the printed docstring
*/
inline static std::string __DOC__() {
std::ostringstream os;
PType::__MANAGER__()->PrintDocString(os);
return os.str();
}
protected:
/*!
   * \brief internal function to allow declaration of a parameter member
* \param manager the parameter manager
* \param key the key name of the parameter
* \param ref the reference to the parameter in the struct.
*/
template<typename DType>
inline parameter::FieldEntry<DType>& DECLARE(
parameter::ParamManagerSingleton<PType> *manager,
const std::string &key, DType &ref) { // NOLINT(*)
parameter::FieldEntry<DType> *e =
new parameter::FieldEntry<DType>();
e->Init(key, this->head(), ref);
manager->manager.AddEntry(key, e);
return *e;
}
private:
/*! \return Get head pointer of child structure */
inline PType *head() const {
return static_cast<PType*>(const_cast<Parameter<PType>*>(this));
}
};
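/*!
 * A minimal usage sketch, assuming the Param struct from the example above
 * has been declared and registered (all names here are illustrative):
 * \code
 * std::map<std::string, std::string> kwargs;
 * kwargs["num_hidden"] = "100";
 * kwargs["learning_rate"] = "0.1";
 * Param param;
 * param.Init(kwargs);  // throws dmlc::ParamError on unknown keys or bad values
 * std::map<std::string, std::string> dict = param.__DICT__();
 * std::string doc = Param::__DOC__();
 * \endcode
 */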
//! \cond Doxygen_Suppress
/*!
* \brief macro used to declare parameter
*
* Example:
* \code
* struct Param : public dmlc::Parameter<Param> {
* // declare parameters in header file
* DMLC_DECLARE_PARAMETER(Param) {
* // details of declarations
* }
* };
* \endcode
*
 * This macro needs to be put in a source file so that registration only happens once.
* Refer to example code in Parameter for details
*
* \param PType the name of parameter struct.
* \sa Parameter
*/
#define DMLC_DECLARE_PARAMETER(PType) \
static ::dmlc::parameter::ParamManager *__MANAGER__(); \
inline void __DECLARE__(::dmlc::parameter::ParamManagerSingleton<PType> *manager) \
/*!
* \brief macro to declare fields
* \param FieldName the name of the field.
*/
#define DMLC_DECLARE_FIELD(FieldName) this->DECLARE(manager, #FieldName, FieldName)
/*!
* \brief macro to declare alias of a fields
* \param FieldName the name of the field.
* \param AliasName the name of the alias, must be declared after the field is declared.
*/
#define DMLC_DECLARE_ALIAS(FieldName, AliasName) manager->manager.AddAlias(#FieldName, #AliasName)
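/*!
 * A short sketch of DMLC_DECLARE_ALIAS, assuming a Param struct with a
 * learning_rate field (names are illustrative); the alias is declared after
 * the field itself:
 * \code
 * DMLC_DECLARE_PARAMETER(Param) {
 *   DMLC_DECLARE_FIELD(learning_rate).set_default(0.01f);
 *   DMLC_DECLARE_ALIAS(learning_rate, lr);
 * }
 * \endcode
 */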
/*!
* \brief Macro used to register parameter.
*
 * This macro needs to be put in a source file so that registration only happens once.
* Refer to example code in Parameter for details
* \param PType the type of parameter struct.
* \sa Parameter
*/
#define DMLC_REGISTER_PARAMETER(PType) \
::dmlc::parameter::ParamManager *PType::__MANAGER__() { \
static ::dmlc::parameter::ParamManagerSingleton<PType> inst(#PType); \
return &inst.manager; \
} \
static DMLC_ATTRIBUTE_UNUSED ::dmlc::parameter::ParamManager& \
__make__ ## PType ## ParamManager__ = \
(*PType::__MANAGER__()) \
//! \endcond
/*!
 * \brief internal namespace for parameter management
* There is no need to use it directly in normal case
*/
namespace parameter {
/*!
* \brief FieldAccessEntry interface to help manage the parameters
* Each entry can be used to access one parameter in the Parameter struct.
*
 *  This is an internal interface that is used to manage parameters
*/
class FieldAccessEntry {
public:
FieldAccessEntry()
: has_default_(false) {}
/*! \brief destructor */
virtual ~FieldAccessEntry() {}
/*!
* \brief set the default value.
* \param head the pointer to the head of the struct
   * \throw error if no default is present
*/
virtual void SetDefault(void *head) const = 0;
/*!
* \brief set the parameter by string value
* \param head the pointer to the head of the struct
* \param value the value to be set
*/
virtual void Set(void *head, const std::string &value) const = 0;
// check if value is OK
virtual void Check(void *head) const {}
/*!
* \brief get the string representation of value.
* \param head the pointer to the head of the struct
*/
virtual std::string GetStringValue(void *head) const = 0;
/*!
* \brief Get field information
* \return the corresponding field information
*/
virtual ParamFieldInfo GetFieldInfo() const = 0;
protected:
  /*! \brief whether this parameter has a default value */
bool has_default_;
/*! \brief positional index of parameter in struct */
size_t index_;
/*! \brief parameter key name */
std::string key_;
/*! \brief parameter type */
std::string type_;
/*! \brief description of the parameter */
std::string description_;
/*!
* \brief print string representation of default value
   * \param os the stream to print the docstring to.
*/
virtual void PrintDefaultValueString(std::ostream &os) const = 0; // NOLINT(*)
// allow ParamManager to modify self
friend class ParamManager;
};
/*!
* \brief manager class to handle parameter structure for each type
 *  A manager will be created for each parameter structure.
*/
class ParamManager {
public:
/*! \brief destructor */
~ParamManager() {
for (size_t i = 0; i < entry_.size(); ++i) {
delete entry_[i];
}
}
/*!
* \brief find the access entry by parameter key
* \param key the key of the parameter.
* \return pointer to FieldAccessEntry, NULL if nothing is found.
*/
inline FieldAccessEntry *Find(const std::string &key) const {
std::map<std::string, FieldAccessEntry*>::const_iterator it =
entry_map_.find(key);
if (it == entry_map_.end()) return NULL;
return it->second;
}
/*!
* \brief set parameter by keyword arguments.
* \param head head to the parameter field.
* \param begin begin iterator of original kwargs
* \param end end iterator of original kwargs
* \param unknown_args optional, used to hold unknown arguments
   *        When it is specified, unknown arguments will be stored here instead of raising an error
* \tparam RandomAccessIterator iterator type
   * \throw ParamError when there is an unknown argument and unknown_args == NULL, or a required argument is missing.
*/
template<typename RandomAccessIterator>
inline void RunInit(void *head,
RandomAccessIterator begin,
RandomAccessIterator end,
std::vector<std::pair<std::string, std::string> > *unknown_args,
parameter::ParamInitOption option) const {
std::set<FieldAccessEntry*> selected_args;
for (RandomAccessIterator it = begin; it != end; ++it) {
FieldAccessEntry *e = Find(it->first);
if (e != NULL) {
e->Set(head, it->second);
e->Check(head);
selected_args.insert(e);
} else {
if (unknown_args != NULL) {
unknown_args->push_back(*it);
} else {
if (option != parameter::kAllowUnknown) {
if (option == parameter::kAllowHidden &&
it->first.length() > 4 &&
it->first.find("__") == 0 &&
it->first.rfind("__") == it->first.length()-2) {
continue;
}
std::ostringstream os;
os << "Cannot find argument \'" << it->first << "\', Possible Arguments:\n";
os << "----------------\n";
PrintDocString(os);
throw dmlc::ParamError(os.str());
}
}
}
}
for (std::map<std::string, FieldAccessEntry*>::const_iterator it = entry_map_.begin();
it != entry_map_.end(); ++it) {
if (selected_args.count(it->second) == 0) {
it->second->SetDefault(head);
}
}
}
/*!
* \brief internal function to add entry to manager,
* The manager will take ownership of the entry.
* \param key the key to the parameters
* \param e the pointer to the new entry.
*/
inline void AddEntry(const std::string &key, FieldAccessEntry *e) {
e->index_ = entry_.size();
// TODO(bing) better error message
if (entry_map_.count(key) != 0) {
LOG(FATAL) << "key " << key << " has already been registered in " << name_;
}
entry_.push_back(e);
entry_map_[key] = e;
}
/*!
   * \brief internal function to add an alias for an existing entry.
   * \param field the name of the field to alias.
   * \param alias the alias name to add.
*/
inline void AddAlias(const std::string& field, const std::string& alias) {
if (entry_map_.count(field) == 0) {
LOG(FATAL) << "key " << field << " has not been registered in " << name_;
}
if (entry_map_.count(alias) != 0) {
LOG(FATAL) << "Alias " << alias << " has already been registered in " << name_;
}
entry_map_[alias] = entry_map_[field];
}
/*!
* \brief set the name of parameter manager
* \param name the name to set
*/
inline void set_name(const std::string &name) {
name_ = name;
}
/*!
* \brief get field information of each field.
* \return field information
*/
inline std::vector<ParamFieldInfo> GetFieldInfo() const {
std::vector<ParamFieldInfo> ret(entry_.size());
for (size_t i = 0; i < entry_.size(); ++i) {
ret[i] = entry_[i]->GetFieldInfo();
}
return ret;
}
/*!
   * \brief Print a readable docstring to the ostream, adding newlines.
   * \param os the stream to print the docstring to.
*/
inline void PrintDocString(std::ostream &os) const { // NOLINT(*)
for (size_t i = 0; i < entry_.size(); ++i) {
ParamFieldInfo info = entry_[i]->GetFieldInfo();
os << info.name << " : " << info.type_info_str << '\n';
if (info.description.length() != 0) {
os << " " << info.description << '\n';
}
}
}
/*!
* \brief Get internal parameters in vector of pairs.
* \param head the head of the struct.
* \return the parameter dictionary.
*/
inline std::vector<std::pair<std::string, std::string> > GetDict(void * head) const {
std::vector<std::pair<std::string, std::string> > ret;
for (std::map<std::string, FieldAccessEntry*>::const_iterator
it = entry_map_.begin(); it != entry_map_.end(); ++it) {
ret.push_back(std::make_pair(it->first, it->second->GetStringValue(head)));
}
return ret;
}
private:
/*! \brief parameter struct name */
std::string name_;
/*! \brief positional list of entries */
std::vector<FieldAccessEntry*> entry_;
/*! \brief map from key to entry */
std::map<std::string, FieldAccessEntry*> entry_map_;
};
//! \cond Doxygen_Suppress
// The following piece of code will be template heavy and less documented
// singleton parameter manager for certain type, used for initialization
template<typename PType>
struct ParamManagerSingleton {
ParamManager manager;
explicit ParamManagerSingleton(const std::string &param_name) {
PType param;
param.__DECLARE__(this);
manager.set_name(param_name);
}
};
// Base class of FieldEntry
// implement set_default
template<typename TEntry, typename DType>
class FieldEntryBase : public FieldAccessEntry {
public:
// entry type
typedef TEntry EntryType;
// implement set value
virtual void Set(void *head, const std::string &value) const {
std::istringstream is(value);
is >> this->Get(head);
if (!is.fail()) {
while (!is.eof()) {
int ch = is.get();
if (ch == EOF) {
is.clear(); break;
}
if (!isspace(ch)) {
is.setstate(std::ios::failbit); break;
}
}
}
if (is.fail()) {
std::ostringstream os;
os << "Invalid Parameter format for " << key_
<< " expect " << type_ << " but value=\'" << value<< '\'';
throw dmlc::ParamError(os.str());
}
}
virtual std::string GetStringValue(void *head) const {
std::ostringstream os;
PrintValue(os, this->Get(head));
return os.str();
}
virtual ParamFieldInfo GetFieldInfo() const {
ParamFieldInfo info;
std::ostringstream os;
info.name = key_;
info.type = type_;
os << type_;
if (has_default_) {
os << ',' << " optional, default=";
PrintDefaultValueString(os);
} else {
os << ", required";
}
info.type_info_str = os.str();
info.description = description_;
return info;
}
// implement set head to default value
virtual void SetDefault(void *head) const {
if (!has_default_) {
std::ostringstream os;
os << "Required parameter " << key_
<< " of " << type_ << " is not presented";
throw dmlc::ParamError(os.str());
} else {
this->Get(head) = default_value_;
}
}
// return reference of self as derived type
inline TEntry &self() {
return *(static_cast<TEntry*>(this));
}
// implement set_default
inline TEntry &set_default(const DType &default_value) {
default_value_ = default_value;
has_default_ = true;
// return self to allow chaining
return this->self();
}
// implement describe
inline TEntry &describe(const std::string &description) {
description_ = description;
// return self to allow chaining
return this->self();
}
// initialization function
inline void Init(const std::string &key,
void *head, DType &ref) { // NOLINT(*)
this->key_ = key;
if (this->type_.length() == 0) {
this->type_ = dmlc::type_name<DType>();
}
this->offset_ = ((char*)&ref) - ((char*)head); // NOLINT(*)
}
protected:
// print the value
virtual void PrintValue(std::ostream &os, DType value) const { // NOLINT(*)
os << value;
}
virtual void PrintDefaultValueString(std::ostream &os) const { // NOLINT(*)
PrintValue(os, default_value_);
}
// get the internal representation of parameter
// for example if this entry corresponds field param.learning_rate
// then Get(&param) will return reference to param.learning_rate
inline DType &Get(void *head) const {
return *(DType*)((char*)(head) + offset_); // NOLINT(*)
}
// internal offset of the field
ptrdiff_t offset_;
// default value of field
DType default_value_;
};
// parameter base for numeric types that have range
template<typename TEntry, typename DType>
class FieldEntryNumeric
: public FieldEntryBase<TEntry, DType> {
public:
FieldEntryNumeric()
: has_begin_(false), has_end_(false) {}
// implement set_range
virtual TEntry &set_range(DType begin, DType end) {
begin_ = begin; end_ = end;
has_begin_ = true; has_end_ = true;
return this->self();
}
  // implement set_lower_bound
virtual TEntry &set_lower_bound(DType begin) {
begin_ = begin; has_begin_ = true;
return this->self();
}
// consistency check for numeric ranges
virtual void Check(void *head) const {
FieldEntryBase<TEntry, DType>::Check(head);
DType v = this->Get(head);
if (has_begin_ && has_end_) {
if (v < begin_ || v > end_) {
std::ostringstream os;
os << "value " << v << " for Parameter " << this->key_
<< " exceed bound [" << begin_ << ',' << end_ <<']';
throw dmlc::ParamError(os.str());
}
} else if (has_begin_ && v < begin_) {
std::ostringstream os;
os << "value " << v << " for Parameter " << this->key_
<< " should be greater equal to " << begin_;
throw dmlc::ParamError(os.str());
} else if (has_end_ && v > end_) {
std::ostringstream os;
os << "value " << v << " for Parameter " << this->key_
<< " should be smaller equal to " << end_;
throw dmlc::ParamError(os.str());
}
}
protected:
  // whether it has begin and end range
bool has_begin_, has_end_;
// data bound
DType begin_, end_;
};
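/*!
 * A brief sketch of using the range setters above inside a field declaration
 * (the field names are illustrative only):
 * \code
 * DMLC_DECLARE_FIELD(num_hidden).set_range(0, 1000);
 * DMLC_DECLARE_FIELD(momentum).set_lower_bound(0.0f).set_default(0.9f);
 * \endcode
 */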
/*!
* \brief FieldEntry defines parsing and checking behavior of DType.
* This class can be specialized to implement specific behavior of more settings.
* \tparam DType the data type of the entry.
*/
template<typename DType>
class FieldEntry :
public IfThenElseType<dmlc::is_arithmetic<DType>::value,
FieldEntryNumeric<FieldEntry<DType>, DType>,
FieldEntryBase<FieldEntry<DType>, DType> >::Type {
};
// specialize define for int(enum)
template<>
class FieldEntry<int>
: public FieldEntryNumeric<FieldEntry<int>, int> {
public:
// construct
FieldEntry<int>() : is_enum_(false) {}
// parent
typedef FieldEntryNumeric<FieldEntry<int>, int> Parent;
// override set
virtual void Set(void *head, const std::string &value) const {
if (is_enum_) {
std::map<std::string, int>::const_iterator it = enum_map_.find(value);
std::ostringstream os;
if (it == enum_map_.end()) {
os << "Invalid Input: \'" << value;
os << "\', valid values are: ";
PrintEnums(os);
throw dmlc::ParamError(os.str());
} else {
os << it->second;
Parent::Set(head, os.str());
}
} else {
Parent::Set(head, value);
}
}
virtual ParamFieldInfo GetFieldInfo() const {
if (is_enum_) {
ParamFieldInfo info;
std::ostringstream os;
info.name = key_;
info.type = type_;
PrintEnums(os);
if (has_default_) {
os << ',' << "optional, default=";
PrintDefaultValueString(os);
} else {
os << ", required";
}
info.type_info_str = os.str();
info.description = description_;
return info;
} else {
return Parent::GetFieldInfo();
}
}
// add enum
inline FieldEntry<int> &add_enum(const std::string &key, int value) {
if ((enum_map_.size() != 0 && enum_map_.count(key) != 0) || \
enum_back_map_.count(value) != 0) {
std::ostringstream os;
      os << "Enum " << "(" << key << ": " << value << " already exists!" << ")\n";
os << "Enums: ";
for (std::map<std::string, int>::const_iterator it = enum_map_.begin();
it != enum_map_.end(); ++it) {
os << "(" << it->first << ": " << it->second << "), ";
}
throw dmlc::ParamError(os.str());
}
enum_map_[key] = value;
enum_back_map_[value] = key;
is_enum_ = true;
return this->self();
}
protected:
// enum flag
bool is_enum_;
// enum map
std::map<std::string, int> enum_map_;
// enum map
std::map<int, std::string> enum_back_map_;
// override print behavior
virtual void PrintDefaultValueString(std::ostream &os) const { // NOLINT(*)
os << '\'';
PrintValue(os, default_value_);
os << '\'';
}
// override print default
virtual void PrintValue(std::ostream &os, int value) const { // NOLINT(*)
if (is_enum_) {
CHECK_NE(enum_back_map_.count(value), 0U)
<< "Value not found in enum declared";
os << enum_back_map_.at(value);
} else {
os << value;
}
}
private:
inline void PrintEnums(std::ostream &os) const { // NOLINT(*)
os << '{';
for (std::map<std::string, int>::const_iterator
it = enum_map_.begin(); it != enum_map_.end(); ++it) {
if (it != enum_map_.begin()) {
os << ", ";
}
os << "\'" << it->first << '\'';
}
os << '}';
}
};
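/*!
 * A small sketch of how the int specialization is typically used to declare
 * an enum-like field (field and enum names are illustrative):
 * \code
 * DMLC_DECLARE_FIELD(act_type)
 *     .add_enum("relu", 0)
 *     .add_enum("tanh", 1)
 *     .set_default(0)
 *     .describe("Activation function type");
 * \endcode
 */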
// specialize define for optional<int>(enum)
template<>
class FieldEntry<optional<int> >
: public FieldEntryBase<FieldEntry<optional<int> >, optional<int> > {
public:
// construct
FieldEntry<optional<int> >() : is_enum_(false) {}
// parent
typedef FieldEntryBase<FieldEntry<optional<int> >, optional<int> > Parent;
// override set
virtual void Set(void *head, const std::string &value) const {
if (is_enum_ && value != "None") {
std::map<std::string, int>::const_iterator it = enum_map_.find(value);
std::ostringstream os;
if (it == enum_map_.end()) {
os << "Invalid Input: \'" << value;
os << "\', valid values are: ";
PrintEnums(os);
throw dmlc::ParamError(os.str());
} else {
os << it->second;
Parent::Set(head, os.str());
}
} else {
Parent::Set(head, value);
}
}
virtual ParamFieldInfo GetFieldInfo() const {
if (is_enum_) {
ParamFieldInfo info;
std::ostringstream os;
info.name = key_;
info.type = type_;
PrintEnums(os);
if (has_default_) {
os << ',' << "optional, default=";
PrintDefaultValueString(os);
} else {
os << ", required";
}
info.type_info_str = os.str();
info.description = description_;
return info;
} else {
return Parent::GetFieldInfo();
}
}
// add enum
inline FieldEntry<optional<int> > &add_enum(const std::string &key, int value) {
CHECK_NE(key, "None") << "None is reserved for empty optional<int>";
if ((enum_map_.size() != 0 && enum_map_.count(key) != 0) || \
enum_back_map_.count(value) != 0) {
std::ostringstream os;
      os << "Enum " << "(" << key << ": " << value << " already exists!" << ")\n";
os << "Enums: ";
for (std::map<std::string, int>::const_iterator it = enum_map_.begin();
it != enum_map_.end(); ++it) {
os << "(" << it->first << ": " << it->second << "), ";
}
throw dmlc::ParamError(os.str());
}
enum_map_[key] = value;
enum_back_map_[value] = key;
is_enum_ = true;
return this->self();
}
protected:
// enum flag
bool is_enum_;
// enum map
std::map<std::string, int> enum_map_;
// enum map
std::map<int, std::string> enum_back_map_;
// override print behavior
virtual void PrintDefaultValueString(std::ostream &os) const { // NOLINT(*)
os << '\'';
PrintValue(os, default_value_);
os << '\'';
}
// override print default
virtual void PrintValue(std::ostream &os, optional<int> value) const { // NOLINT(*)
if (is_enum_) {
if (!value) {
os << "None";
} else {
CHECK_NE(enum_back_map_.count(value.value()), 0U)
<< "Value not found in enum declared";
os << enum_back_map_.at(value.value());
}
} else {
os << value;
}
}
private:
inline void PrintEnums(std::ostream &os) const { // NOLINT(*)
os << "{None";
for (std::map<std::string, int>::const_iterator
it = enum_map_.begin(); it != enum_map_.end(); ++it) {
os << ", ";
os << "\'" << it->first << '\'';
}
os << '}';
}
};
// specialize define for string
template<>
class FieldEntry<std::string>
: public FieldEntryBase<FieldEntry<std::string>, std::string> {
public:
// parent class
typedef FieldEntryBase<FieldEntry<std::string>, std::string> Parent;
// override set
virtual void Set(void *head, const std::string &value) const {
this->Get(head) = value;
}
// override print default
virtual void PrintDefaultValueString(std::ostream &os) const { // NOLINT(*)
os << '\'' << default_value_ << '\'';
}
};
// specialize define for bool
template<>
class FieldEntry<bool>
: public FieldEntryBase<FieldEntry<bool>, bool> {
public:
// parent class
typedef FieldEntryBase<FieldEntry<bool>, bool> Parent;
// override set
virtual void Set(void *head, const std::string &value) const {
std::string lower_case; lower_case.resize(value.length());
std::transform(value.begin(), value.end(), lower_case.begin(), ::tolower);
bool &ref = this->Get(head);
if (lower_case == "true") {
ref = true;
} else if (lower_case == "false") {
ref = false;
} else if (lower_case == "1") {
ref = true;
} else if (lower_case == "0") {
ref = false;
} else {
std::ostringstream os;
os << "Invalid Parameter format for " << key_
<< " expect " << type_ << " but value=\'" << value<< '\'';
throw dmlc::ParamError(os.str());
}
}
protected:
// print default string
virtual void PrintValue(std::ostream &os, bool value) const { // NOLINT(*)
if (value) {
os << "True";
} else {
os << "False";
}
}
};
} // namespace parameter
//! \endcond
// implement GetEnv
template<typename ValueType>
inline ValueType GetEnv(const char *key,
ValueType default_value) {
const char *val = getenv(key);
if (val == NULL) return default_value;
ValueType ret;
parameter::FieldEntry<ValueType> e;
e.Init(key, &ret, ret);
e.Set(&ret, val);
return ret;
}
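/*!
 * A minimal sketch of GetEnv; the environment variable names below are
 * illustrative only:
 * \code
 * int nthread = dmlc::GetEnv("MY_NUM_THREADS", 4);
 * bool verbose = dmlc::GetEnv("MY_VERBOSE", false);
 * \endcode
 */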
} // namespace dmlc
#endif // DMLC_PARAMETER_H_
//===== EXPANDED: ../dmlc-core/include/dmlc/parameter.h =====
namespace dmlc {
/*!
* \brief Registry class.
* Registry can be used to register global singletons.
 *  The most common use case is factory functions.
*
* \tparam EntryType Type of Registry entries,
 *   EntryType needs to have a name field.
*/
template<typename EntryType>
class Registry {
public:
  /*! \return list of entries in the registry (excluding aliases) */
inline static const std::vector<const EntryType*>& List() {
return Get()->const_list_;
}
  /*! \return list of all names registered in the registry, including aliases */
inline static std::vector<std::string> ListAllNames() {
const std::map<std::string, EntryType*> &fmap = Get()->fmap_;
typename std::map<std::string, EntryType*>::const_iterator p;
std::vector<std::string> names;
for (p = fmap.begin(); p !=fmap.end(); ++p) {
names.push_back(p->first);
}
return names;
}
/*!
* \brief Find the entry with corresponding name.
* \param name name of the function
* \return the corresponding function, can be NULL
*/
inline static const EntryType *Find(const std::string &name) {
const std::map<std::string, EntryType*> &fmap = Get()->fmap_;
typename std::map<std::string, EntryType*>::const_iterator p = fmap.find(name);
if (p != fmap.end()) {
return p->second;
} else {
return NULL;
}
}
/*!
* \brief Add alias to the key_name
* \param key_name The original entry key
* \param alias The alias key.
*/
inline void AddAlias(const std::string& key_name,
const std::string& alias) {
EntryType* e = fmap_.at(key_name);
if (fmap_.count(alias)) {
CHECK_EQ(e, fmap_.at(alias))
<< "Entry " << e->name << " already registered under different entry";
} else {
fmap_[alias] = e;
}
}
/*!
* \brief Internal function to register a name function under name.
* \param name name of the function
* \return ref to the registered entry, used to set properties
*/
inline EntryType &__REGISTER__(const std::string& name) {
CHECK_EQ(fmap_.count(name), 0U)
<< name << " already registered";
EntryType *e = new EntryType();
e->name = name;
fmap_[name] = e;
const_list_.push_back(e);
entry_list_.push_back(e);
return *e;
}
/*!
* \brief Internal function to either register or get registered entry
* \param name name of the function
* \return ref to the registered entry, used to set properties
*/
inline EntryType &__REGISTER_OR_GET__(const std::string& name) {
if (fmap_.count(name) == 0) {
return __REGISTER__(name);
} else {
return *fmap_.at(name);
}
}
/*!
* \brief get a singleton of the Registry.
* This function can be defined by DMLC_ENABLE_REGISTRY.
* \return get a singleton
*/
static Registry *Get();
private:
/*! \brief list of entry types */
std::vector<EntryType*> entry_list_;
/*! \brief list of entry types */
std::vector<const EntryType*> const_list_;
/*! \brief map of name->function */
std::map<std::string, EntryType*> fmap_;
/*! \brief constructor */
Registry() {}
/*! \brief destructor */
~Registry() {
for (size_t i = 0; i < entry_list_.size(); ++i) {
delete entry_list_[i];
}
}
};
/*!
* \brief Common base class for function registry.
*
* \code
* // This example demonstrates how to use Registry to create a factory of trees.
* struct TreeFactory :
* public FunctionRegEntryBase<TreeFactory, std::function<Tree*()> > {
* };
*
* // in a independent cc file
* namespace dmlc {
* DMLC_REGISTRY_ENABLE(TreeFactory);
* }
* // register binary tree constructor into the registry.
* DMLC_REGISTRY_REGISTER(TreeFactory, TreeFactory, BinaryTree)
* .describe("Constructor of BinaryTree")
* .set_body([]() { return new BinaryTree(); });
* \endcode
*
 * \tparam EntryType The type of subclass that inherits the base.
 * \tparam FunctionType The function type this registry is registered with.
*/
template<typename EntryType, typename FunctionType>
class FunctionRegEntryBase {
public:
/*! \brief name of the entry */
std::string name;
/*! \brief description of the entry */
std::string description;
/*! \brief additional arguments to the factory function */
std::vector<ParamFieldInfo> arguments;
/*! \brief Function body to create ProductType */
FunctionType body;
/*! \brief Return type of the function */
std::string return_type;
/*!
* \brief Set the function body.
* \param body Function body to set.
* \return reference to self.
*/
inline EntryType &set_body(FunctionType body) {
this->body = body;
return this->self();
}
/*!
* \brief Describe the function.
* \param description The description of the factory function.
* \return reference to self.
*/
inline EntryType &describe(const std::string &description) {
this->description = description;
return this->self();
}
/*!
* \brief Add argument information to the function.
* \param name Name of the argument.
* \param type Type of the argument.
* \param description Description of the argument.
* \return reference to self.
*/
inline EntryType &add_argument(const std::string &name,
const std::string &type,
const std::string &description) {
ParamFieldInfo info;
info.name = name;
info.type = type;
info.type_info_str = info.type;
info.description = description;
arguments.push_back(info);
return this->self();
}
/*!
   * \brief Append a list of arguments to the end.
* \param args Additional list of arguments.
* \return reference to self.
*/
inline EntryType &add_arguments(const std::vector<ParamFieldInfo> &args) {
arguments.insert(arguments.end(), args.begin(), args.end());
return this->self();
}
/*!
* \brief Set the return type.
* \param type Return type of the function, could be Symbol or Symbol[]
* \return reference to self.
*/
inline EntryType &set_return_type(const std::string &type) {
return_type = type;
return this->self();
}
protected:
/*!
* \return reference of self as derived type
*/
inline EntryType &self() {
return *(static_cast<EntryType*>(this));
}
};
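/*!
 * A short lookup sketch that continues the TreeFactory example above
 * (names are illustrative): find a registered entry and invoke its body.
 * \code
 * const TreeFactory* e = dmlc::Registry<TreeFactory>::Find("BinaryTree");
 * if (e != NULL) {
 *   Tree* tree = e->body();  // call the registered factory; caller owns the tree
 * }
 * \endcode
 */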
/*!
* \def DMLC_REGISTRY_ENABLE
* \brief Macro to enable the registry of EntryType.
* This macro must be used under namespace dmlc, and only used once in cc file.
* \param EntryType Type of registry entry
*/
#define DMLC_REGISTRY_ENABLE(EntryType) \
template<> \
Registry<EntryType > *Registry<EntryType >::Get() { \
static Registry<EntryType > inst; \
return &inst; \
} \
/*!
* \brief Generic macro to register an EntryType
* There is a complete example in FactoryRegistryEntryBase.
*
* \param EntryType The type of registry entry.
 * \param EntryTypeName The typename of EntryType, must not contain namespace :: .
* \param Name The name to be registered.
* \sa FactoryRegistryEntryBase
*/
#define DMLC_REGISTRY_REGISTER(EntryType, EntryTypeName, Name) \
static DMLC_ATTRIBUTE_UNUSED EntryType & __make_ ## EntryTypeName ## _ ## Name ## __ = \
::dmlc::Registry<EntryType>::Get()->__REGISTER__(#Name) \
/*!
 * \brief (Optional) Declare a file tag for the current file that contains object registrations.
*
* This will declare a dummy function that will be called by register file to
* incur a link dependency.
*
* \param UniqueTag The unique tag used to represent.
* \sa DMLC_REGISTRY_LINK_TAG
*/
#define DMLC_REGISTRY_FILE_TAG(UniqueTag) \
int __dmlc_registry_file_tag_ ## UniqueTag ## __() { return 0; }
/*!
* \brief (Optional) Force link to all the objects registered in file tag.
*
* This macro must be used in the same file as DMLC_REGISTRY_ENABLE and
* in the same namespace as DMLC_REGISTRY_FILE_TAG
*
* DMLC_REGISTRY_FILE_TAG and DMLC_REGISTRY_LINK_TAG are optional macros for registration.
 * They are used to enforce linking of certain files during static linking.
 *
 * This is mainly used to solve problems when statically linking a library which contains backward registration.
 * Specifically, this prevents the objects in these file tags from being ignored by the compiler.
 *
 * For dynamic linking, this problem won't occur as everything is loaded by default.
 *
 * Use of this is optional as it will create an error when a file tag does not exist.
 * An alternative solution is to always ask the user to enable --whole-archive during static linking.
*
 * \code
* // in file objective_registry.cc
* DMLC_REGISTRY_ENABLE(MyObjective);
* DMLC_REGISTRY_LINK_TAG(regression_op);
* DMLC_REGISTRY_LINK_TAG(rank_op);
*
* // in file regression_op.cc
* // declare tag of this file.
* DMLC_REGISTRY_FILE_TAG(regression_op);
* DMLC_REGISTRY_REGISTER(MyObjective, logistic_reg, logistic_reg);
* // ...
*
* // in file rank_op.cc
* // declare tag of this file.
* DMLC_REGISTRY_FILE_TAG(rank_op);
* DMLC_REGISTRY_REGISTER(MyObjective, pairwiserank, pairwiserank);
*
* \endcode
*
* \param UniqueTag The unique tag used to represent.
* \sa DMLC_REGISTRY_ENABLE, DMLC_REGISTRY_FILE_TAG
*/
#define DMLC_REGISTRY_LINK_TAG(UniqueTag) \
int __dmlc_registry_file_tag_ ## UniqueTag ## __(); \
static int DMLC_ATTRIBUTE_UNUSED __reg_file_tag_ ## UniqueTag ## __ = \
__dmlc_registry_file_tag_ ## UniqueTag ## __();
} // namespace dmlc
#endif // DMLC_REGISTRY_H_
//===== EXPANDED: ../dmlc-core/include/dmlc/registry.h =====
namespace dmlc {
/*!
 * \brief this defines the floating point type
 *  that will be used to store feature values
*/
typedef float real_t;
/*!
* \brief this defines the unsigned integer type
* that can normally be used to store feature index
*/
typedef unsigned index_t;
// This file describes common data structure that can be used
// for large-scale machine learning, this may not be a complete list
// But we will keep the most common and useful ones, and keep adding new ones
/*!
* \brief data iterator interface
* this is not a C++ style iterator, but nice for data pulling:)
* This interface is used to pull in the data
* The system can do some useful tricks for you like pre-fetching
* from disk and pre-computation.
*
* Usage example:
* \code
*
* itr->BeforeFirst();
* while (itr->Next()) {
* const DType &batch = itr->Value();
* // some computations
* }
* \endcode
* \tparam DType the data type
*/
template<typename DType>
class DataIter {
public:
/*! \brief destructor */
virtual ~DataIter(void) {}
/*! \brief set before first of the item */
virtual void BeforeFirst(void) = 0;
/*! \brief move to next item */
virtual bool Next(void) = 0;
/*! \brief get current data */
virtual const DType &Value(void) const = 0;
};
/*!
* \brief one row of training instance
* \tparam IndexType type of index
*/
template<typename IndexType>
class Row {
public:
/*! \brief label of the instance */
real_t label;
/*! \brief weight of the instance */
real_t weight;
/*! \brief length of the sparse vector */
size_t length;
/*!
* \brief index of each instance
*/
const IndexType *index;
/*!
* \brief array value of each instance, this can be NULL
* indicating every value is set to be 1
*/
const real_t *value;
/*!
* \param i the input index
* \return i-th feature
*/
inline IndexType get_index(size_t i) const {
return index[i];
}
/*!
* \param i the input index
* \return i-th feature value, this function is always
* safe even when value == NULL
*/
inline real_t get_value(size_t i) const {
return value == NULL ? 1.0f : value[i];
}
/*!
   * \brief helper function to compute the dot product of the current row with a dense weight vector
   * \param weight the dense weight array to compute the dot product with
* \param size the size of the weight vector
* \tparam V type of the weight vector
* \return the result of dot product
*/
template<typename V>
inline V SDot(const V *weight, size_t size) const {
V sum = static_cast<V>(0);
if (value == NULL) {
for (size_t i = 0; i < length; ++i) {
CHECK(index[i] < size) << "feature index exceed bound";
sum += weight[index[i]];
}
} else {
for (size_t i = 0; i < length; ++i) {
CHECK(index[i] < size) << "feature index exceed bound";
sum += weight[index[i]] * value[i];
}
}
return sum;
}
};
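/*!
 * A tiny sketch of scoring a single row against a dense weight vector
 * (the variables below are illustrative):
 * \code
 * // given a Row<unsigned> row and a std::vector<float> weight
 * float score = row.SDot(dmlc::BeginPtr(weight), weight.size());
 * \endcode
 */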
/*!
* \brief a block of data, containing several rows in sparse matrix
 *  This is useful for (streaming-style) algorithms that scan through rows of data
* examples include: SGD, GD, L-BFGS, kmeans
*
* The size of batch is usually large enough so that parallelizing over the rows
* can give significant speedup
* \tparam IndexType type to store the index used in row batch
*/
template<typename IndexType>
struct RowBlock {
/*! \brief batch size */
size_t size;
  /*! \brief array[size+1], row pointer to the beginning of each row */
const size_t *offset;
/*! \brief array[size] label of each instance */
const real_t *label;
  /*! \brief With weight: array[size] weight of each instance, otherwise nullptr */
const real_t *weight;
/*! \brief feature index */
const IndexType *index;
/*! \brief feature value, can be NULL, indicating all values are 1 */
const real_t *value;
/*!
* \brief get specific rows in the batch
* \param rowid the rowid in that row
* \return the instance corresponding to the row
*/
inline Row<IndexType> operator[](size_t rowid) const;
/*! \return memory cost of the block in bytes */
inline size_t MemCostBytes(void) const {
size_t cost = size * (sizeof(size_t) + sizeof(real_t));
if (weight != NULL) cost += size * sizeof(real_t);
size_t ndata = offset[size] - offset[0];
if (index != NULL) cost += ndata * sizeof(IndexType);
if (value != NULL) cost += ndata * sizeof(real_t);
return cost;
}
/*!
* \brief slice a RowBlock to get rows in [begin, end)
* \param begin the begin row index
* \param end the end row index
* \return the sliced RowBlock
*/
inline RowBlock Slice(size_t begin, size_t end) const {
CHECK(begin <= end && end <= size);
RowBlock ret;
ret.size = end - begin;
ret.label = label + begin;
if (weight != NULL) {
ret.weight = weight + begin;
} else {
ret.weight = NULL;
}
ret.offset = offset + begin;
ret.index = index;
ret.value = value;
return ret;
}
};
/*!
* \brief Data structure that holds the data
* Row block iterator interface that gets RowBlocks
* Difference between RowBlockIter and Parser:
 *      RowBlockIter caches the data internally so it can be used
 *      to iterate over the dataset multiple times;
 *      Parser holds very limited internal state and is usually
 *      used to read the data only once
*
* \sa Parser
* \tparam IndexType type of index in RowBlock
 *  The Create function is only implemented for IndexType uint64_t and uint32_t
*/
template<typename IndexType>
class RowBlockIter : public DataIter<RowBlock<IndexType> > {
public:
/*!
* \brief create a new instance of iterator that returns rowbatch
* by default, a in-memory based iterator will be returned
*
* \param uri the uri of the input, can contain hdfs prefix
* \param part_index the part id of current input
* \param num_parts total number of splits
* \param type type of dataset can be: "libsvm", ...
*
* \return the created data iterator
*/
static RowBlockIter<IndexType> *
Create(const char *uri,
unsigned part_index,
unsigned num_parts,
const char *type);
/*! \return maximum feature dimension in the dataset */
virtual size_t NumCol() const = 0;
};
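/*!
 * A minimal sketch of creating and consuming a RowBlockIter
 * (the URI below is illustrative):
 * \code
 * dmlc::RowBlockIter<uint32_t>* iter =
 *     dmlc::RowBlockIter<uint32_t>::Create("train.libsvm", 0, 1, "libsvm");
 * iter->BeforeFirst();
 * while (iter->Next()) {
 *   const dmlc::RowBlock<uint32_t>& batch = iter->Value();
 *   // process batch.size rows
 * }
 * delete iter;
 * \endcode
 */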
/*!
* \brief parser interface that parses input data
* used to load dmlc data format into your own data format
* Difference between RowBlockIter and Parser:
 *     RowBlockIter caches the data internally so it can be used
 *     to iterate over the dataset multiple times;
 *     Parser holds very limited internal state and is usually
 *     used to read the data only once
*
*
* \sa RowBlockIter
* \tparam IndexType type of index in RowBlock
 *  The Create function is only implemented for IndexType uint64_t and uint32_t
*/
template <typename IndexType>
class Parser : public DataIter<RowBlock<IndexType> > {
public:
/*!
* \brief create a new instance of parser based on the "type"
*
* \param uri_ the uri of the input, can contain hdfs prefix
* \param part_index the part id of current input
* \param num_parts total number of splits
* \param type type of dataset can be: "libsvm", "auto", ...
*
   * When "auto" is passed, the type is decided by the format argument string in the URI.
*
* \return the created parser
*/
static Parser<IndexType> *
Create(const char *uri_,
unsigned part_index,
unsigned num_parts,
const char *type);
/*! \return size of bytes read so far */
virtual size_t BytesRead(void) const = 0;
/*! \brief Factory type of the parser*/
typedef Parser<IndexType>* (*Factory)
(const std::string& path,
const std::map<std::string, std::string>& args,
unsigned part_index,
unsigned num_parts);
};
/*!
* \brief registry entry of parser factory
* \tparam IndexType The type of index
*/
template<typename IndexType>
struct ParserFactoryReg
: public FunctionRegEntryBase<ParserFactoryReg<IndexType>,
typename Parser<IndexType>::Factory> {};
/*!
* \brief Register a new distributed parser to dmlc-core.
*
* \param IndexType The type of Batch index, can be uint32_t or uint64_t
 * \param TypeName The typename of the data.
* \param FactoryFunction The factory function that creates the parser.
*
 * \code
*
 * // define the factory function
* template<typename IndexType>
* Parser<IndexType>*
* CreateLibSVMParser(const char* uri, unsigned part_index, unsigned num_parts) {
* return new LibSVMParser(uri, part_index, num_parts);
* }
*
* // Register it to DMLC
* // Then we can use Parser<uint32_t>::Create(uri, part_index, num_parts, "libsvm");
* // to create the parser
*
* DMLC_REGISTER_DATA_PARSER(uint32_t, libsvm, CreateLibSVMParser<uint32_t>);
* DMLC_REGISTER_DATA_PARSER(uint64_t, libsvm, CreateLibSVMParser<uint64_t>);
*
* \endcode
*/
#define DMLC_REGISTER_DATA_PARSER(IndexType, TypeName, FactoryFunction) \
DMLC_REGISTRY_REGISTER(::dmlc::ParserFactoryReg<IndexType>, \
ParserFactoryReg ## _ ## IndexType, TypeName) \
.set_body(FactoryFunction)
// implementation of operator[]
template<typename IndexType>
inline Row<IndexType>
RowBlock<IndexType>::operator[](size_t rowid) const {
CHECK(rowid < size);
Row<IndexType> inst;
inst.label = label[rowid];
if (weight != NULL) {
inst.weight = weight[rowid];
} else {
inst.weight = 1.0f;
}
inst.length = offset[rowid + 1] - offset[rowid];
inst.index = index + offset[rowid];
if (value == NULL) {
inst.value = NULL;
} else {
inst.value = value + offset[rowid];
}
return inst;
}
} // namespace dmlc
#endif // DMLC_DATA_H_
//===== EXPANDED: ../dmlc-core/include/dmlc/data.h =====
//===== EXPANDING: ../dmlc-core/src/io/uri_spec.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file uri_spec.h
* \brief common specification of sugars in URI
* string passed to dmlc Create functions
* such as local file cache
* \author Tianqi Chen
*/
#ifndef DMLC_IO_URI_SPEC_H_
#define DMLC_IO_URI_SPEC_H_
namespace dmlc {
namespace io {
/*!
 * \brief a superset of URI
* that allows sugars to be passed around
* Example:
*
* hdfs:///mylibsvm/?format=libsvm&clabel=0#mycache-file.
*/
class URISpec {
public:
/*! \brief the real URI */
std::string uri;
/*! \brief arguments in the URL */
std::map<std::string, std::string> args;
/*! \brief the path to cache file */
std::string cache_file;
/*!
* \brief constructor.
* \param uri The raw uri string.
   * \param part_index The partition index of the part.
* \param num_parts total number of parts.
*/
explicit URISpec(const std::string& uri,
unsigned part_index,
unsigned num_parts) {
std::vector<std::string> name_cache = Split(uri, '#');
if (name_cache.size() == 2) {
std::ostringstream os;
os << name_cache[1];
if (num_parts != 1) {
os << ".split" << num_parts << ".part" << part_index;
}
this->cache_file = os.str();
} else {
CHECK_EQ(name_cache.size(), 1U)
<< "only one `#` is allowed in file path for cachefile specification";
}
std::vector<std::string> name_args = Split(name_cache[0], '?');
if (name_args.size() == 2) {
std::vector<std::string> arg_list = Split(name_args[1], '&');
for (size_t i = 0; i < arg_list.size(); ++i) {
std::istringstream is(arg_list[i]);
std::pair<std::string, std::string> kv;
CHECK(std::getline(is, kv.first, '=')) << "Invalid uri argument format"
<< " for key in arg " << i + 1;
CHECK(std::getline(is, kv.second)) << "Invalid uri argument format"
<< " for value in arg " << i + 1;
this->args.insert(kv);
}
} else {
      CHECK_EQ(name_args.size(), 1U)
          << "only one `?` is allowed in file path for argument specification";
}
this->uri = name_args[0];
}
};
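/*!
 * A brief sketch of what the constructor above extracts from the example URI
 * in the class docstring:
 * \code
 * dmlc::io::URISpec spec("hdfs:///mylibsvm/?format=libsvm&clabel=0#mycache-file", 0, 1);
 * // spec.uri        == "hdfs:///mylibsvm/"
 * // spec.args       == {{"format", "libsvm"}, {"clabel", "0"}}
 * // spec.cache_file == "mycache-file"
 * \endcode
 */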
} // namespace io
} // namespace dmlc
#endif // DMLC_IO_URI_SPEC_H_
//===== EXPANDED: ../dmlc-core/src/io/uri_spec.h =====
//===== EXPANDING: ../dmlc-core/src/data/parser.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file libsvm_parser.h
* \brief iterator parser to parse libsvm format
* \author Tianqi Chen
*/
#ifndef DMLC_DATA_PARSER_H_
#define DMLC_DATA_PARSER_H_
//===== EXPANDING: ../dmlc-core/include/dmlc/threadediter.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file threadediter.h
* \brief thread backed iterator that can be used to implement
* general thread-based pipeline such as prefetch and pre-computation
* To use the functions in this header, C++11 is required
* \author Tianqi Chen
*/
#ifndef DMLC_THREADEDITER_H_
#define DMLC_THREADEDITER_H_
// defines DMLC_USE_CXX11
// this code depends on c++11
#if DMLC_ENABLE_STD_THREAD
namespace dmlc {
/*!
 * \brief an iterator that is backed by a thread
* to pull data eagerly from a single producer into a bounded buffer
* the consumer can pull the data at its own rate
*
 * NOTE: thread concurrency costs time, make sure to store big blobs of data in DType
*
* Usage example:
* \code
* ThreadedIter<DType> iter;
* iter.Init(&producer);
* // the following code can be in parallel
* DType *dptr;
* while (iter.Next(&dptr)) {
* // do something on dptr
* // recycle the space
* iter.Recycle(&dptr);
* }
* \endcode
* \tparam DType the type of data blob we support
*/
template<typename DType>
class ThreadedIter : public DataIter<DType> {
public:
/*!
* \brief producer class interface
   *  that the threaded iterator uses as a source to
   *  produce the content
*/
class Producer {
public:
// virtual destructor
virtual ~Producer() {}
/*! \brief reset the producer to beginning */
virtual void BeforeFirst(void) {
NotImplemented();
}
/*!
* \brief load the data content into DType,
* the caller can pass in NULL or an existing address
* when inout_dptr is NULL:
     *     the producer needs to allocate a DType and fill the content
     *  when inout_dptr is specified:
     *     the producer needs to fill the content into the address
     *     specified by inout_dptr, or delete the old one and create a new one
*
* \param inout_dptr used to pass in the data holder cell
* and return the address of the cell filled
* \return true if there is next record, false if we reach the end
*/
virtual bool Next(DType **inout_dptr) = 0;
};
/*!
* \brief constructor
* \param max_capacity maximum capacity of the queue
*/
explicit ThreadedIter(size_t max_capacity = 8)
: producer_owned_(NULL),
producer_thread_(NULL),
max_capacity_(max_capacity),
nwait_consumer_(0),
nwait_producer_(0),
out_data_(NULL) {}
/*! \brief destructor */
virtual ~ThreadedIter(void) {
this->Destroy();
}
/*!
* \brief destroy all the related resources
* this is equivalent to destructor, can be used
   *  to destroy the threaded iterator when the user thinks it is
   *  appropriate; it is safe to call this multiple times
*/
inline void Destroy(void);
/*!
* \brief set maximum capacity of the queue
* \param max_capacity maximum capacity of the queue
*/
inline void set_max_capacity(size_t max_capacity) {
max_capacity_ = max_capacity;
}
/*!
* \brief initialize the producer and start the thread
* can only be called once
* \param producer pointer to the producer
* \param pass_ownership whether pass the ownership to the iter
* if this is true, the threaditer will delete the producer
* when destructed
*/
inline void Init(Producer *producer, bool pass_ownership = false);
/*!
* \brief initialize the producer and start the thread
   *  pass in two functions (closures) of the producer to represent the producer
* the beforefirst function is optional, and defaults to not implemented
* NOTE: the closure must remain valid until the ThreadedIter destructs
* \param next the function called to get next element, see Producer.Next
* \param beforefirst the function to call to reset the producer, see Producer.BeforeFirst
*/
inline void Init(std::function<bool(DType **)> next,
std::function<void()> beforefirst = NotImplemented);
/*!
* \brief get the next data, this function is threadsafe
* \param out_dptr used to hold the pointer to the record
* after the function call, the caller takes ownership of the pointer
* the caller can call recycle to return ownership back to the threaditer
* so that the pointer can be re-used
* \return true if there is next record, false if we reach the end
* \sa Recycle
*/
inline bool Next(DType **out_dptr);
/*!
* \brief recycle the data cell, this function is threadsafe
* the threaditer can reuse the data cell for future data loading
* \param inout_dptr pointer to the dptr to recycle, after the function call
* the content of inout_dptr will be set to NULL
*/
inline void Recycle(DType **inout_dptr);
/*!
* \brief adapt the iterator interface's Next
* NOTE: the call to this function is not threadsafe
* use the other Next instead
* \return true if there is next record, false if we reach the end
*/
virtual bool Next(void) {
if (out_data_ != NULL) {
this->Recycle(&out_data_);
}
if (Next(&out_data_)) {
return true;
} else {
return false;
}
}
/*!
* \brief adapt the iterator interface's Value
* NOTE: the call to this function is not threadsafe
* use the other Next instead
*/
virtual const DType &Value(void) const {
CHECK(out_data_ != NULL) << "Calling Value at beginning or end?";
return *out_data_;
}
/*! \brief set the iterator before first location */
virtual void BeforeFirst(void) {
std::unique_lock<std::mutex> lock(mutex_);
if (out_data_ != NULL) {
free_cells_.push(out_data_);
out_data_ = NULL;
}
if (producer_sig_ == kDestroy) return;
producer_sig_ = kBeforeFirst;
CHECK(!producer_sig_processed_);
if (nwait_producer_ != 0) {
producer_cond_.notify_one();
}
CHECK(!producer_sig_processed_);
// wait until the request has been processed
consumer_cond_.wait(lock, [this]() {
return producer_sig_processed_;
});
producer_sig_processed_ = false;
bool notify = nwait_producer_ != 0 && !produce_end_;
lock.unlock();
// notify producer, in case they are waiting for the condition.
if (notify) producer_cond_.notify_one();
}
private:
  /*! \brief used when BeforeFirst is not supported */
inline static void NotImplemented(void) {
LOG(FATAL) << "BeforeFirst is not supported";
}
/*! \brief signals send to producer */
enum Signal {
kProduce,
kBeforeFirst,
kDestroy
};
/*! \brief producer class */
Producer *producer_owned_;
/*! \brief signal to producer */
Signal producer_sig_;
  /*! \brief whether the special signal other than kProduce is processed */
bool producer_sig_processed_;
/*! \brief thread that runs the producer */
std::thread *producer_thread_;
/*! \brief whether produce ends */
bool produce_end_;
/*! \brief maximum queue size */
size_t max_capacity_;
/*! \brief internal mutex */
std::mutex mutex_;
  /*! \brief number of consumers waiting */
  unsigned nwait_consumer_;
  /*! \brief number of producers waiting */
  unsigned nwait_producer_;
/*! \brief conditional variable for producer thread */
std::condition_variable producer_cond_;
/*! \brief conditional variable for consumer threads */
std::condition_variable consumer_cond_;
/*! \brief the current output cell */
DType *out_data_;
/*! \brief internal queue of producer */
std::queue<DType*> queue_;
/*! \brief free cells that can be used */
std::queue<DType*> free_cells_;
};
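/*!
 * A minimal sketch of the closure-based Init above; the int payload is
 * illustrative only, real uses store large data blobs per cell:
 * \code
 * int counter = 0;
 * dmlc::ThreadedIter<int> iter;
 * iter.Init([&counter](int **dptr) {
 *   if (counter >= 10) return false;   // end of stream
 *   if (*dptr == NULL) *dptr = new int();
 *   **dptr = counter++;
 *   return true;
 * });
 * int *value;
 * while (iter.Next(&value)) {
 *   // consume *value, then hand the cell back for reuse
 *   iter.Recycle(&value);
 * }
 * \endcode
 */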
// implementation of functions
template<typename DType>
inline void ThreadedIter<DType>::Destroy(void) {
if (producer_thread_ != NULL) {
{
// lock the mutex
std::lock_guard<std::mutex> lock(mutex_);
// send destroy signal
producer_sig_ = kDestroy;
if (nwait_producer_ != 0) {
producer_cond_.notify_one();
}
}
producer_thread_->join();
delete producer_thread_;
producer_thread_ = NULL;
}
// end of critical region
// now the slave thread should exit
while (free_cells_.size() != 0) {
delete free_cells_.front();
free_cells_.pop();
}
while (queue_.size() != 0) {
delete queue_.front();
queue_.pop();
}
if (producer_owned_ != NULL) {
delete producer_owned_;
}
if (out_data_ != NULL) {
delete out_data_; out_data_ = NULL;
}
}
template<typename DType>
inline void ThreadedIter<DType>::
Init(Producer *producer, bool pass_ownership) {
CHECK(producer_owned_ == NULL) << "can only call Init once";
if (pass_ownership) producer_owned_ = producer;
auto next = [producer](DType **dptr) {
return producer->Next(dptr);
};
auto beforefirst = [producer]() {
producer->BeforeFirst();
};
this->Init(next, beforefirst);
}
template<typename DType>
inline void ThreadedIter<DType>::
Init(std::function<bool(DType **)> next,
std::function<void()> beforefirst) {
producer_sig_ = kProduce;
producer_sig_processed_ = false;
produce_end_ = false;
  // procedure running in producer
// run producer thread
auto producer_fun = [this, next, beforefirst] () {
while (true) {
DType *cell = NULL;
{
// lockscope
std::unique_lock<std::mutex> lock(mutex_);
++this->nwait_producer_;
producer_cond_.wait(lock, [this]() {
if (producer_sig_ == kProduce) {
bool ret = !produce_end_ &&
(queue_.size() < max_capacity_ || free_cells_.size() != 0);
return ret;
} else {
return true;
}
});
--this->nwait_producer_;
if (producer_sig_ == kProduce) {
if (free_cells_.size() != 0) {
cell = free_cells_.front();
free_cells_.pop();
}
} else if (producer_sig_ == kBeforeFirst) {
// reset the producer
beforefirst();
// cleanup the queue
while (queue_.size() != 0) {
free_cells_.push(queue_.front());
queue_.pop();
}
// reset the state
produce_end_ = false;
producer_sig_processed_ = true;
producer_sig_ = kProduce;
          // notify consumers that all the processing has been done.
lock.unlock();
consumer_cond_.notify_all();
continue;
} else {
// destroy the thread
CHECK(producer_sig_ == kDestroy);
producer_sig_processed_ = true;
produce_end_ = true;
consumer_cond_.notify_all();
return;
}
} // end of lock scope
// now without lock
produce_end_ = !next(&cell);
CHECK(cell != NULL || produce_end_);
bool notify;
{
// lockscope
std::lock_guard<std::mutex> lock(mutex_);
if (!produce_end_) {
queue_.push(cell);
} else {
if (cell != NULL) free_cells_.push(cell);
}
// put things into queue
notify = nwait_consumer_ != 0;
}
if (notify) consumer_cond_.notify_all();
}
};
producer_thread_ = new std::thread(producer_fun);
}
template<typename DType>
inline bool ThreadedIter<DType>::
Next(DType **out_dptr) {
if (producer_sig_ == kDestroy) return false;
std::unique_lock<std::mutex> lock(mutex_);
CHECK(producer_sig_ == kProduce)
      << "Make sure you do not call BeforeFirst concurrently with Next!";
++nwait_consumer_;
consumer_cond_.wait(lock, [this]() {
return queue_.size() != 0 || produce_end_;
});
--nwait_consumer_;
if (queue_.size() != 0) {
*out_dptr = queue_.front();
queue_.pop();
bool notify = nwait_producer_ != 0 && !produce_end_;
lock.unlock();
if (notify) producer_cond_.notify_one();
return true;
} else {
CHECK(produce_end_);
return false;
}
}
template<typename DType>
inline void ThreadedIter<DType>::Recycle(DType **inout_dptr) {
bool notify;
{
std::lock_guard<std::mutex> lock(mutex_);
free_cells_.push(*inout_dptr);
*inout_dptr = NULL;
notify = nwait_producer_ != 0 && !produce_end_;
}
if (notify) producer_cond_.notify_one();
}
} // namespace dmlc
#endif  // DMLC_ENABLE_STD_THREAD
#endif // DMLC_THREADEDITER_H_
//===== EXPANDED: ../dmlc-core/include/dmlc/threadediter.h =====
//===== EXPANDING: ../dmlc-core/src/data/row_block.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file row_block.h
* \brief additional data structure to support
* RowBlock data structure
* \author Tianqi Chen
*/
#ifndef DMLC_DATA_ROW_BLOCK_H_
#define DMLC_DATA_ROW_BLOCK_H_
namespace dmlc {
namespace data {
/*!
* \brief dynamic data structure that holds
* a row block of data
* \tparam IndexType the type of index we are using
*/
template<typename IndexType>
struct RowBlockContainer {
  /*! \brief array[size+1], row pointer to the beginning of each row */
std::vector<size_t> offset;
/*! \brief array[size] label of each instance */
std::vector<real_t> label;
/*! \brief array[size] weight of each instance */
std::vector<real_t> weight;
/*! \brief feature index */
std::vector<IndexType> index;
/*! \brief feature value */
std::vector<real_t> value;
/*! \brief maximum value of index */
IndexType max_index;
// constructor
RowBlockContainer(void) {
this->Clear();
}
/*! \brief convert to a row block */
inline RowBlock<IndexType> GetBlock(void) const;
/*!
* \brief write the row block to a binary stream
* \param fo output stream
*/
inline void Save(Stream *fo) const;
/*!
* \brief load row block from a binary stream
   * \param fi input stream
* \return false if at end of file
*/
inline bool Load(Stream *fi);
/*! \brief clear the container */
inline void Clear(void) {
offset.clear(); offset.push_back(0);
label.clear(); index.clear(); value.clear(); weight.clear();
max_index = 0;
}
/*! \brief size of the data */
inline size_t Size(void) const {
return offset.size() - 1;
}
/*! \return estimation of memory cost of this container */
inline size_t MemCostBytes(void) const {
return offset.size() * sizeof(size_t) +
label.size() * sizeof(real_t) +
weight.size() * sizeof(real_t) +
index.size() * sizeof(IndexType) +
value.size() * sizeof(real_t);
}
/*!
* \brief push the row into container
* \param row the row to push back
* \tparam I the index type of the row
*/
template<typename I>
inline void Push(Row<I> row) {
label.push_back(row.label);
weight.push_back(row.weight);
for (size_t i = 0; i < row.length; ++i) {
CHECK_LE(row.index[i], std::numeric_limits<IndexType>::max())
<< "index exceed numeric bound of current type";
IndexType findex = static_cast<IndexType>(row.index[i]);
index.push_back(findex);
max_index = std::max(max_index, findex);
}
if (row.value != NULL) {
for (size_t i = 0; i < row.length; ++i) {
value.push_back(row.value[i]);
}
}
offset.push_back(index.size());
}
/*!
* \brief push the row block into container
   * \param batch the row block to push back
* \tparam I the index type of the row
*/
template<typename I>
inline void Push(RowBlock<I> batch) {
size_t size = label.size();
label.resize(label.size() + batch.size);
std::memcpy(BeginPtr(label) + size, batch.label,
batch.size * sizeof(real_t));
if (batch.weight != NULL) {
weight.insert(weight.end(), batch.weight, batch.weight + batch.size);
}
size_t ndata = batch.offset[batch.size] - batch.offset[0];
index.resize(index.size() + ndata);
IndexType *ihead = BeginPtr(index) + offset.back();
for (size_t i = 0; i < ndata; ++i) {
CHECK_LE(batch.index[i], std::numeric_limits<IndexType>::max())
<< "index exceed numeric bound of current type";
IndexType findex = static_cast<IndexType>(batch.index[i]);
ihead[i] = findex;
max_index = std::max(max_index, findex);
}
if (batch.value != NULL) {
value.resize(value.size() + ndata);
std::memcpy(BeginPtr(value) + value.size() - ndata, batch.value,
ndata * sizeof(real_t));
}
size_t shift = offset[size];
offset.resize(offset.size() + batch.size);
size_t *ohead = BeginPtr(offset) + size + 1;
for (size_t i = 0; i < batch.size; ++i) {
ohead[i] = shift + batch.offset[i + 1] - batch.offset[0];
}
}
};
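/*!
 * A tiny sketch of accumulating rows and exposing them as a RowBlock
 * (illustrative only; in practice rows usually come from a parser):
 * \code
 * dmlc::data::RowBlockContainer<uint32_t> cont;
 * // cont.Push(row) for each incoming Row, or cont.Push(batch) for a RowBlock
 * dmlc::RowBlock<uint32_t> block = cont.GetBlock();
 * \endcode
 */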
template<typename IndexType>
inline RowBlock<IndexType>
RowBlockContainer<IndexType>::GetBlock(void) const {
// consistency check
if (label.size()) {
CHECK_EQ(label.size() + 1, offset.size());
}
CHECK_EQ(offset.back(), index.size());
CHECK(offset.back() == value.size() || value.size() == 0);
RowBlock<IndexType> data;
data.size = offset.size() - 1;
data.offset = BeginPtr(offset);
data.label = BeginPtr(label);
data.weight = BeginPtr(weight);
data.index = BeginPtr(index);
data.value = BeginPtr(value);
return data;
}
template<typename IndexType>
inline void
RowBlockContainer<IndexType>::Save(Stream *fo) const {
fo->Write(offset);
fo->Write(label);
fo->Write(weight);
fo->Write(index);
fo->Write(value);
fo->Write(&max_index, sizeof(IndexType));
}
template<typename IndexType>
inline bool
RowBlockContainer<IndexType>::Load(Stream *fi) {
if (!fi->Read(&offset)) return false;
CHECK(fi->Read(&label)) << "Bad RowBlock format";
CHECK(fi->Read(&weight)) << "Bad RowBlock format";
CHECK(fi->Read(&index)) << "Bad RowBlock format";
CHECK(fi->Read(&value)) << "Bad RowBlock format";
CHECK(fi->Read(&max_index, sizeof(IndexType))) << "Bad RowBlock format";
return true;
}
} // namespace data
} // namespace dmlc
#endif // DMLC_DATA_ROW_BLOCK_H_
//===== EXPANDED: ../dmlc-core/src/data/row_block.h =====
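// Usage sketch (illustrative, not compiled): how a caller can fill a
// RowBlockContainer row by row and view it as a zero-copy RowBlock. The
// helper name ExampleFillRowBlock and the literal feature data are made up;
// the Row<> field names follow the Push() implementation above.
#if 0  // illustrative sketch only
inline void ExampleFillRowBlock() {
  dmlc::data::RowBlockContainer<uint32_t> cont;
  uint32_t idx[2] = {3, 7};
  dmlc::real_t val[2] = {0.5f, 1.25f};
  dmlc::Row<uint32_t> row;
  row.label = 1.0f;
  row.weight = 1.0f;
  row.length = 2;
  row.index = idx;
  row.value = val;
  cont.Push(row);                                  // append one sparse row
  dmlc::RowBlock<uint32_t> blk = cont.GetBlock();  // CSR-style view, no copy
  CHECK_EQ(blk.size, 1U);
}
#endif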
namespace dmlc {
namespace data {
/*! \brief declare thread class */
template <typename IndexType>
class ThreadedParser;
/*! \brief base class for parser to parse data */
template <typename IndexType>
class ParserImpl : public Parser<IndexType> {
public:
ParserImpl() : data_ptr_(0), data_end_(0) {}
// virtual destructor
virtual ~ParserImpl() {}
/*! \brief implement next */
virtual bool Next(void) {
while (true) {
while (data_ptr_ < data_end_) {
data_ptr_ += 1;
if (data_[data_ptr_ - 1].Size() != 0) {
block_ = data_[data_ptr_ - 1].GetBlock();
return true;
}
}
if (!ParseNext(&data_)) break;
data_ptr_ = 0;
data_end_ = static_cast<IndexType>(data_.size());
}
return false;
}
virtual const RowBlock<IndexType> &Value(void) const {
return block_;
}
/*! \return size of bytes read so far */
virtual size_t BytesRead(void) const = 0;
protected:
// allow ThreadedParser to see ParseNext
friend class ThreadedParser<IndexType>;
/*!
* \brief read in next several blocks of data
* \param data vector of data to be returned
   * \return true if the data is loaded, false if the end is reached
*/
virtual bool ParseNext(std::vector<RowBlockContainer<IndexType> > *data) = 0;
/*! \brief pointer to begin and end of data */
IndexType data_ptr_, data_end_;
/*! \brief internal data */
std::vector<RowBlockContainer<IndexType> > data_;
/*! \brief internal row block */
RowBlock<IndexType> block_;
};
#if DMLC_ENABLE_STD_THREAD
template <typename IndexType>
class ThreadedParser : public ParserImpl<IndexType> {
public:
explicit ThreadedParser(ParserImpl<IndexType> *base)
: base_(base), tmp_(NULL) {
iter_.set_max_capacity(8);
iter_.Init([base](std::vector<RowBlockContainer<IndexType> > **dptr) {
if (*dptr == NULL) {
*dptr = new std::vector<RowBlockContainer<IndexType> >();
}
return base->ParseNext(*dptr);
}, [base]() {base->BeforeFirst();});
}
virtual ~ThreadedParser(void) {
// stop things before base is deleted
iter_.Destroy();
delete base_;
delete tmp_;
}
virtual void BeforeFirst() {
iter_.BeforeFirst();
}
/*! \brief implement next */
using ParserImpl<IndexType>::data_ptr_;
using ParserImpl<IndexType>::data_end_;
virtual bool Next(void) {
while (true) {
while (data_ptr_ < data_end_) {
data_ptr_ += 1;
if ((*tmp_)[data_ptr_ - 1].Size() != 0) {
this->block_ = (*tmp_)[data_ptr_ - 1].GetBlock();
return true;
}
}
if (tmp_ != NULL) iter_.Recycle(&tmp_);
if (!iter_.Next(&tmp_)) break;
data_ptr_ = 0; data_end_ = tmp_->size();
}
return false;
}
virtual size_t BytesRead(void) const {
return base_->BytesRead();
}
protected:
virtual bool ParseNext(std::vector<RowBlockContainer<IndexType> > *data) {
LOG(FATAL) << "cannot call ParseNext"; return false;
}
private:
/*! \brief the place where we get the data */
Parser<IndexType> *base_;
/*! \brief backend threaded iterator */
ThreadedIter<std::vector<RowBlockContainer<IndexType> > > iter_;
/*! \brief current chunk of data */
std::vector<RowBlockContainer<IndexType> > *tmp_;
};
#endif  // DMLC_ENABLE_STD_THREAD
} // namespace data
} // namespace dmlc
#endif // DMLC_DATA_PARSER_H_
//===== EXPANDED: ../dmlc-core/src/data/parser.h =====
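// Usage sketch (illustrative, not compiled): the calling convention shared by
// every ParserImpl -- Next() advances to the next non-empty row block, Value()
// exposes it, BytesRead() reports raw input progress. This mirrors the loop in
// BasicRowIter::Init() below; the parser is assumed to come from
// Parser<uint32_t>::Create().
#if 0  // illustrative sketch only
inline void ExampleDrainParser(dmlc::Parser<uint32_t> *parser) {
  size_t nrows = 0;
  while (parser->Next()) {
    const dmlc::RowBlock<uint32_t> &blk = parser->Value();
    nrows += blk.size;
  }
  LOG(INFO) << nrows << " rows, "
            << (parser->BytesRead() >> 20UL) << " MB read";
}
#endif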
//===== EXPANDING: ../dmlc-core/src/data/basic_row_iter.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file basic_row_iter.h
* \brief row based iterator that
* loads in everything into memory and returns
* \author Tianqi Chen
*/
#ifndef DMLC_DATA_BASIC_ROW_ITER_H_
#define DMLC_DATA_BASIC_ROW_ITER_H_
//===== EXPANDING: ../dmlc-core/include/dmlc/timer.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file timer.h
* \brief cross platform timer for timing
* \author Tianqi Chen
*/
#ifndef DMLC_TIMER_H_
#define DMLC_TIMER_H_
#if DMLC_USE_CXX11
#endif
#ifdef __MACH__
#endif
namespace dmlc {
/*!
* \brief return time in seconds
*/
inline double GetTime(void) {
#if DMLC_USE_CXX11
return std::chrono::duration<double>(
std::chrono::high_resolution_clock::now().time_since_epoch()).count();
#elif defined __MACH__
clock_serv_t cclock;
mach_timespec_t mts;
host_get_clock_service(mach_host_self(), CALENDAR_CLOCK, &cclock);
CHECK(clock_get_time(cclock, &mts) == 0) << "failed to get time";
mach_port_deallocate(mach_task_self(), cclock);
return static_cast<double>(mts.tv_sec) + static_cast<double>(mts.tv_nsec) * 1e-9;
#else
#if defined(__unix__) || defined(__linux__)
timespec ts;
CHECK(clock_gettime(CLOCK_REALTIME, &ts) == 0) << "failed to get time";
return static_cast<double>(ts.tv_sec) + static_cast<double>(ts.tv_nsec) * 1e-9;
#else
return static_cast<double>(time(NULL));
#endif
#endif
}
} // namespace dmlc
#endif // DMLC_TIMER_H_
//===== EXPANDED: ../dmlc-core/include/dmlc/timer.h =====
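// Usage sketch (illustrative, not compiled): GetTime() returns wall-clock time
// in seconds as a double, so elapsed time is a plain difference; this is the
// pattern the row iterators below use for throughput logging.
#if 0  // illustrative sketch only
inline void ExampleTimeSection() {
  double tstart = dmlc::GetTime();
  // ... do some work ...
  double tdiff = dmlc::GetTime() - tstart;
  LOG(INFO) << "section took " << tdiff << " sec";
}
#endif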
namespace dmlc {
namespace data {
/*!
 * \brief basic row iterator that loads the entire data into memory
 *  and returns it as a single row block
* \tparam IndexType the type of index we are using
*/
template<typename IndexType>
class BasicRowIter: public RowBlockIter<IndexType> {
public:
explicit BasicRowIter(Parser<IndexType> *parser)
: at_head_(true) {
this->Init(parser);
delete parser;
}
virtual ~BasicRowIter() {}
virtual void BeforeFirst(void) {
at_head_ = true;
}
virtual bool Next(void) {
if (at_head_) {
at_head_ = false;
return true;
} else {
return false;
}
}
virtual const RowBlock<IndexType> &Value(void) const {
return row_;
}
virtual size_t NumCol(void) const {
return static_cast<size_t>(data_.max_index) + 1;
}
private:
// at head
bool at_head_;
// row block to store
RowBlock<IndexType> row_;
// back end data
RowBlockContainer<IndexType> data_;
// initialize
inline void Init(Parser<IndexType> *parser);
};
template<typename IndexType>
inline void BasicRowIter<IndexType>::Init(Parser<IndexType> *parser) {
data_.Clear();
double tstart = GetTime();
size_t bytes_expect = 10UL << 20UL;
while (parser->Next()) {
data_.Push(parser->Value());
double tdiff = GetTime() - tstart;
size_t bytes_read = parser->BytesRead();
if (bytes_read >= bytes_expect) {
bytes_read = bytes_read >> 20UL;
LOG(INFO) << bytes_read << "MB read,"
<< bytes_read / tdiff << " MB/sec";
bytes_expect += 10UL << 20UL;
}
}
row_ = data_.GetBlock();
double tdiff = GetTime() - tstart;
LOG(INFO) << "finish reading at "
<< (parser->BytesRead() >> 20UL) / tdiff
<< " MB/sec";
}
} // namespace data
} // namespace dmlc
#endif  // DMLC_DATA_BASIC_ROW_ITER_H_
//===== EXPANDED: ../dmlc-core/src/data/basic_row_iter.h =====
//===== EXPANDING: ../dmlc-core/src/data/disk_row_iter.h =====
/*!
* Copyright (c) 2015 by Contributors
 * \file disk_row_iter.h
 * \brief row-based iterator that
 *   caches data onto disk and then loads it back in segments
* \author Tianqi Chen
*/
#ifndef DMLC_DATA_DISK_ROW_ITER_H_
#define DMLC_DATA_DISK_ROW_ITER_H_
//===== EXPANDING: ../dmlc-core/src/data/libsvm_parser.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file libsvm_parser.h
* \brief iterator parser to parse libsvm format
* \author Tianqi Chen
*/
#ifndef DMLC_DATA_LIBSVM_PARSER_H_
#define DMLC_DATA_LIBSVM_PARSER_H_
//===== EXPANDING: ../dmlc-core/src/data/text_parser.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file text_parser.h
* \brief iterator parser to parse text format
* \author Tianqi Chen
*/
#ifndef DMLC_DATA_TEXT_PARSER_H_
#define DMLC_DATA_TEXT_PARSER_H_
//===== EXPANDING: ../dmlc-core/include/dmlc/omp.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file omp.h
* \brief header to handle OpenMP compatibility issues
*/
#ifndef DMLC_OMP_H_
#define DMLC_OMP_H_
#if defined(_OPENMP)
#else
#ifndef DISABLE_OPENMP
// use pragma message instead of warning
#pragma message("Warning: OpenMP is not available, " \
"project will be compiled into single-thread code. " \
"Use OpenMP-enabled compiler to get benefit of multi-threading.")
#endif
//! \cond Doxygen_Suppress
inline int omp_get_thread_num() { return 0; }
inline int omp_get_num_threads() { return 1; }
inline int omp_get_max_threads() { return 1; }
inline int omp_get_num_procs() { return 1; }
inline void omp_set_num_threads(int nthread) {}
#endif
// loop variable used in openmp
namespace dmlc {
#ifdef _MSC_VER
typedef int omp_uint;
typedef long omp_ulong; // NOLINT(*)
#else
typedef unsigned omp_uint;
typedef unsigned long omp_ulong; // NOLINT(*)
#endif
//! \endcond
} // namespace dmlc
#endif // DMLC_OMP_H_
//===== EXPANDED: ../dmlc-core/include/dmlc/omp.h =====
namespace dmlc {
namespace data {
/*!
* \brief Text parser that parses the input lines
* and returns rows in input data
*/
template <typename IndexType>
class TextParserBase : public ParserImpl<IndexType> {
public:
explicit TextParserBase(InputSplit *source,
int nthread)
: bytes_read_(0), source_(source) {
int maxthread;
#pragma omp parallel
{
maxthread = std::max(omp_get_num_procs() / 2 - 4, 1);
}
nthread_ = std::min(maxthread, nthread);
}
virtual ~TextParserBase() {
delete source_;
}
virtual void BeforeFirst(void) {
source_->BeforeFirst();
}
virtual size_t BytesRead(void) const {
return bytes_read_;
}
virtual bool ParseNext(std::vector<RowBlockContainer<IndexType> > *data) {
return FillData(data);
}
protected:
/*!
* \brief parse data into out
* \param begin beginning of buffer
* \param end end of buffer
*/
virtual void ParseBlock(char *begin,
char *end,
RowBlockContainer<IndexType> *out) = 0;
/*!
* \brief read in next several blocks of data
* \param data vector of data to be returned
   * \return true if the data is loaded, false if the end is reached
*/
inline bool FillData(std::vector<RowBlockContainer<IndexType> > *data);
/*!
   * \brief start from bptr, go backward and find the first end-of-line
   * \param bptr end position to go backward from
   * \param begin the beginning position of the buffer
   * \return position of the first end-of-line found going backward
*/
inline char* BackFindEndLine(char *bptr,
char *begin) {
for (; bptr != begin; --bptr) {
if (*bptr == '\n' || *bptr == '\r') return bptr;
}
return begin;
}
private:
// nthread
int nthread_;
  // number of bytes read
size_t bytes_read_;
// source split that provides the data
InputSplit *source_;
};
// implementation
template <typename IndexType>
inline bool TextParserBase<IndexType>::
FillData(std::vector<RowBlockContainer<IndexType> > *data) {
InputSplit::Blob chunk;
if (!source_->NextChunk(&chunk)) return false;
const int nthread = omp_get_max_threads();
// reserve space for data
data->resize(nthread);
bytes_read_ += chunk.size;
CHECK_NE(chunk.size, 0U);
char *head = reinterpret_cast<char*>(chunk.dptr);
#pragma omp parallel num_threads(nthread)
{
// threadid
int tid = omp_get_thread_num();
size_t nstep = (chunk.size + nthread - 1) / nthread;
size_t sbegin = std::min(tid * nstep, chunk.size);
size_t send = std::min((tid + 1) * nstep, chunk.size);
char *pbegin = BackFindEndLine(head + sbegin, head);
char *pend;
if (tid + 1 == nthread) {
pend = head + send;
} else {
pend = BackFindEndLine(head + send, head);
}
ParseBlock(pbegin, pend, &(*data)[tid]);
}
this->data_ptr_ = 0;
return true;
}
} // namespace data
} // namespace dmlc
#endif // DMLC_DATA_TEXT_PARSER_H_
//===== EXPANDED: ../dmlc-core/src/data/text_parser.h =====
//===== EXPANDING: ../dmlc-core/src/data/strtonum.h =====
/*!
 * Copyright (c) 2015 by Contributors
* \file strtonum.h
* \brief A faster implementation of strtod, ...
*/
#ifndef DMLC_DATA_STRTONUM_H_
#define DMLC_DATA_STRTONUM_H_
namespace dmlc {
namespace data {
inline bool isspace(char c) {
return (c == ' ' || c == '\t' || c == '\r' || c == '\n' || c == '\f');
}
inline bool isblank(char c) {
return (c == ' ' || c == '\t');
}
inline bool isdigit(char c) {
return (c >= '0' && c <= '9');
}
inline bool isdigitchars(char c) {
return (c >= '0' && c <= '9')
|| c == '+' || c == '-'
|| c == '.'
|| c == 'e' || c == 'E';
}
/*!
* \brief A faster version of strtof
 * TODO: the current version does not support INF, NAN, or hex numbers
*/
inline float strtof(const char *nptr, char **endptr) {
const char *p = nptr;
// Skip leading white space, if any. Not necessary
while (isspace(*p) ) ++p;
// Get sign, if any.
bool sign = true;
if (*p == '-') {
sign = false; ++p;
} else if (*p == '+') {
++p;
}
// Get digits before decimal point or exponent, if any.
float value;
for (value = 0; isdigit(*p); ++p) {
value = value * 10.0f + (*p - '0');
}
// Get digits after decimal point, if any.
if (*p == '.') {
uint64_t pow10 = 1;
uint64_t val2 = 0;
++p;
while (isdigit(*p)) {
val2 = val2 * 10 + (*p - '0');
pow10 *= 10;
++p;
}
value += static_cast<float>(
static_cast<double>(val2) / static_cast<double>(pow10));
}
// Handle exponent, if any.
if ((*p == 'e') || (*p == 'E')) {
++p;
bool frac = false;
float scale = 1.0;
unsigned expon;
// Get sign of exponent, if any.
if (*p == '-') {
frac = true;
++p;
} else if (*p == '+') {
++p;
}
// Get digits of exponent, if any.
for (expon = 0; isdigit(*p); p += 1) {
expon = expon * 10 + (*p - '0');
}
if (expon > 38) expon = 38;
// Calculate scaling factor.
while (expon >= 8) { scale *= 1E8; expon -= 8; }
while (expon > 0) { scale *= 10.0; expon -= 1; }
// Return signed and scaled floating point result.
value = frac ? (value / scale) : (value * scale);
}
if (endptr) *endptr = (char*)p; // NOLINT(*)
return sign ? value : - value;
}
/**
 * \brief A faster string-to-integer converter
 * TODO: only supports base <= 10
*/
template <typename V>
inline V strtoint(const char* nptr, char **endptr, int base) {
const char *p = nptr;
// Skip leading white space, if any. Not necessary
while (isspace(*p) ) ++p;
// Get sign if any
bool sign = true;
if (*p == '-') {
sign = false; ++p;
} else if (*p == '+') {
++p;
}
V value;
for (value = 0; isdigit(*p); ++p) {
value = value * base + (*p - '0');
}
if (endptr) *endptr = (char*)p; // NOLINT(*)
return sign ? value : - value;
}
template <typename V>
inline V strtouint(const char* nptr, char **endptr, int base) {
const char *p = nptr;
// Skip leading white space, if any. Not necessary
while (isspace(*p)) ++p;
// Get sign if any
bool sign = true;
if (*p == '-') {
sign = false; ++p;
} else if (*p == '+') {
++p;
}
// we are parsing unsigned, so no minus sign should be found
CHECK_EQ(sign, true);
V value;
for (value = 0; isdigit(*p); ++p) {
value = value * base + (*p - '0');
}
if (endptr) *endptr = (char*)p; // NOLINT(*)
return value;
}
inline uint64_t
strtoull(const char* nptr, char **endptr, int base) {
return strtouint<uint64_t>(nptr, endptr, base);
}
inline long atol(const char* p) { // NOLINT(*)
return strtoint<long>(p, 0, 10); // NOLINT(*)
}
inline float atof(const char *nptr) {
return strtof(nptr, 0);
}
template<typename T>
class Str2T {
public:
static inline T get(const char * begin, const char * end);
};
template<typename T>
inline T Str2Type(const char * begin, const char * end) {
return Str2T<T>::get(begin, end);
}
template<>
class Str2T<int32_t> {
public:
static inline int32_t get(const char * begin, const char * end) {
return strtoint<int>(begin, NULL, 10);
}
};
template<>
class Str2T<uint32_t> {
public:
static inline uint32_t get(const char * begin, const char * end) {
    return strtouint<uint32_t>(begin, NULL, 10);
}
};
template<>
class Str2T<int64_t> {
public:
static inline int64_t get(const char * begin, const char * end) {
return strtoint<int64_t>(begin, NULL, 10);
}
};
template<>
class Str2T<uint64_t> {
public:
static inline uint64_t get(const char * begin, const char * end) {
return strtouint<uint64_t>(begin, NULL, 10);
}
};
template<>
class Str2T<float> {
public:
static inline float get(const char * begin, const char * end) {
return atof(begin);
}
};
/**
 * \brief Parse a colon-separated pair v1[:v2]
 * \param begin pointer to the string
 * \param end one past the end of the string
 * \param endptr set to one past the end of the parsed substring
 * \param v1 first value in the pair
 * \param v2 second value in the pair
 * \return the number of values parsed (0, 1, or 2)
*/
template<typename T1, typename T2>
inline int ParsePair(const char * begin, const char * end,
const char ** endptr, T1 &v1, T2 &v2) { // NOLINT(*)
const char * p = begin;
while (p != end && !isdigitchars(*p)) ++p;
if (p == end) {
*endptr = end;
return 0;
}
const char * q = p;
while (q != end && isdigitchars(*q)) ++q;
v1 = Str2Type<T1>(p, q);
p = q;
while (p != end && isblank(*p)) ++p;
if (p == end || *p != ':') {
// only v1
*endptr = p;
return 1;
}
p++;
while (p != end && !isdigitchars(*p)) ++p;
q = p;
while (q != end && isdigitchars(*q)) ++q;
*endptr = q;
v2 = Str2Type<T2>(p, q);
return 2;
}
} // namespace data
} // namespace dmlc
#endif // DMLC_DATA_STRTONUM_H_
//===== EXPANDED: ../dmlc-core/src/data/strtonum.h =====
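// Usage sketch (illustrative, not compiled): ParsePair() consumes an optional
// "v1[:v2]" token and reports how many values it found; this is how the LibSVM
// parser below distinguishes "label" from "label:weight" and "feature" from
// "feature:value". The literal input here is made up.
#if 0  // illustrative sketch only
inline void ExampleParsePair() {
  const char *token = "7:0.25";
  const char *end = token + std::strlen(token);
  const char *tail;
  uint32_t index = 0;
  dmlc::real_t value = 0.0f;
  int n = dmlc::data::ParsePair<uint32_t, dmlc::real_t>(
      token, end, &tail, index, value);
  CHECK_EQ(n, 2);        // both index and value were present
  CHECK_EQ(index, 7U);   // value == 0.25f, tail == end
}
#endif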
namespace dmlc {
namespace data {
/*!
 * \brief LibSVM parser that parses the input lines
 *   and returns rows of the input data
*/
template <typename IndexType>
class LibSVMParser : public TextParserBase<IndexType> {
public:
explicit LibSVMParser(InputSplit *source,
int nthread)
: TextParserBase<IndexType>(source, nthread) {}
protected:
virtual void ParseBlock(char *begin,
char *end,
RowBlockContainer<IndexType> *out);
};
template <typename IndexType>
void LibSVMParser<IndexType>::
ParseBlock(char *begin,
char *end,
RowBlockContainer<IndexType> *out) {
out->Clear();
char * lbegin = begin;
char * lend = lbegin;
while (lbegin != end) {
// get line end
lend = lbegin + 1;
while (lend != end && *lend != '\n' && *lend != '\r') ++lend;
// parse label[:weight]
const char * p = lbegin;
const char * q = NULL;
real_t label;
real_t weight;
int r = ParsePair<real_t, real_t>(p, lend, &q, label, weight);
if (r < 1) {
// empty line
lbegin = lend;
continue;
}
if (r == 2) {
// has weight
out->weight.push_back(weight);
}
if (out->label.size() != 0) {
out->offset.push_back(out->index.size());
}
out->label.push_back(label);
// parse feature[:value]
p = q;
while (p != lend) {
IndexType featureId;
real_t value;
int r = ParsePair<IndexType, real_t>(p, lend, &q, featureId, value);
if (r < 1) {
p = q;
continue;
}
out->index.push_back(featureId);
if (r == 2) {
// has value
out->value.push_back(value);
}
p = q;
}
// next line
lbegin = lend;
}
if (out->label.size() != 0) {
out->offset.push_back(out->index.size());
}
CHECK(out->label.size() + 1 == out->offset.size());
}
} // namespace data
} // namespace dmlc
#endif // DMLC_DATA_LIBSVM_PARSER_H_
//===== EXPANDED: ../dmlc-core/src/data/libsvm_parser.h =====
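// Input sketch (illustrative): lines accepted by LibSVMParser::ParseBlock
// above have the form "label[:weight] index[:value] index[:value] ...", e.g.
//
//   1 3:0.5 7:1.25
//   0:2.5 3:1 8:0.5     (label 0 with an instance weight of 2.5)
//
// Rows accumulate into a RowBlockContainer in CSR form: `offset` marks row
// boundaries while `index`/`value` hold the features of every row.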
#if DMLC_ENABLE_STD_THREAD
namespace dmlc {
namespace data {
/*!
 * \brief row iterator that caches parsed data to a disk file
 *  and streams it back in segments
* \tparam IndexType the type of index we are using
*/
template<typename IndexType>
class DiskRowIter: public RowBlockIter<IndexType> {
public:
// page size 64MB
static const size_t kPageSize = 64UL << 20UL;
/*!
* \brief disk row iterator constructor
* \param parser parser used to generate this
*/
explicit DiskRowIter(Parser<IndexType> *parser,
const char *cache_file,
bool reuse_cache)
: cache_file_(cache_file), fi_(NULL) {
if (reuse_cache) {
if (!TryLoadCache()) {
this->BuildCache(parser);
CHECK(TryLoadCache())
<< "failed to build cache file " << cache_file;
}
} else {
this->BuildCache(parser);
CHECK(TryLoadCache())
<< "failed to build cache file " << cache_file;
}
delete parser;
}
virtual ~DiskRowIter(void) {
iter_.Destroy();
delete fi_;
}
virtual void BeforeFirst(void) {
iter_.BeforeFirst();
}
virtual bool Next(void) {
if (iter_.Next()) {
row_ = iter_.Value().GetBlock();
return true;
} else {
return false;
}
}
virtual const RowBlock<IndexType> &Value(void) const {
return row_;
}
virtual size_t NumCol(void) const {
return num_col_;
}
private:
// file place
std::string cache_file_;
// input stream
SeekStream *fi_;
// maximum feature dimension
size_t num_col_;
// row block to store
RowBlock<IndexType> row_;
// iterator
ThreadedIter<RowBlockContainer<IndexType> > iter_;
// load disk cache file
inline bool TryLoadCache(void);
// build disk cache
inline void BuildCache(Parser<IndexType> *parser);
};
// try to load disk cache
template<typename IndexType>
inline bool DiskRowIter<IndexType>::TryLoadCache(void) {
SeekStream *fi = SeekStream::CreateForRead(cache_file_.c_str(), true);
if (fi == NULL) return false;
this->fi_ = fi;
iter_.Init([fi](RowBlockContainer<IndexType> **dptr) {
      if (*dptr == NULL) {
*dptr = new RowBlockContainer<IndexType>();
}
return (*dptr)->Load(fi);
},
[fi]() { fi->Seek(0); });
return true;
}
template<typename IndexType>
inline void DiskRowIter<IndexType>::
BuildCache(Parser<IndexType> *parser) {
Stream *fo = Stream::Create(cache_file_.c_str(), "w");
// back end data
RowBlockContainer<IndexType> data;
num_col_ = 0;
double tstart = GetTime();
while (parser->Next()) {
data.Push(parser->Value());
double tdiff = GetTime() - tstart;
if (data.MemCostBytes() >= kPageSize) {
size_t bytes_read = parser->BytesRead();
bytes_read = bytes_read >> 20UL;
LOG(INFO) << bytes_read << "MB read,"
<< bytes_read / tdiff << " MB/sec";
      // record the feature dimension before the container is cleared
      num_col_ = std::max(num_col_,
                          static_cast<size_t>(data.max_index) + 1);
      data.Save(fo);
      data.Clear();
}
}
  if (data.Size() != 0) {
    num_col_ = std::max(num_col_,
                        static_cast<size_t>(data.max_index) + 1);
    data.Save(fo);
  }
delete fo;
double tdiff = GetTime() - tstart;
LOG(INFO) << "finish reading at %g MB/sec"
<< (parser->BytesRead() >> 20UL) / tdiff;
}
} // namespace data
} // namespace dmlc
#endif  // DMLC_ENABLE_STD_THREAD
#endif // DMLC_DATA_DISK_ROW_ITER_H_
//===== EXPANDED: ../dmlc-core/src/data/disk_row_iter.h =====
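// Usage sketch (illustrative, not compiled): DiskRowIter is selected by
// CreateIter_() below when the URI carries a cache-file suffix in the
// "path#cachefile" form handled by io::URISpec; the file names here are
// made up.
#if 0  // illustrative sketch only
dmlc::RowBlockIter<uint32_t> *it =
    dmlc::RowBlockIter<uint32_t>::Create("train.libsvm#train.cache", 0, 1, "libsvm");
// first pass: parse the text and write 64MB binary pages into train.cache;
// later passes: stream the cached pages back through ThreadedIter
#endif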
//===== EXPANDING: ../dmlc-core/src/data/csv_parser.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file csv_parser.h
* \brief iterator parser to parse csv format
* \author Tianqi Chen
*/
#ifndef DMLC_DATA_CSV_PARSER_H_
#define DMLC_DATA_CSV_PARSER_H_
namespace dmlc {
namespace data {
struct CSVParserParam : public Parameter<CSVParserParam> {
std::string format;
int label_column;
// declare parameters
DMLC_DECLARE_PARAMETER(CSVParserParam) {
DMLC_DECLARE_FIELD(format).set_default("csv")
.describe("File format.");
DMLC_DECLARE_FIELD(label_column).set_default(-1)
.describe("Column index that will put into label.");
}
};
/*!
* \brief CSVParser, parses a dense csv format.
 *  Currently a simple implementation: when the label column is not specified,
 *  all columns are treated as dense real-valued data and the label is set to 0.
 *
 *  This should be extended in the future to accept per-column type arguments.
*/
template <typename IndexType>
class CSVParser : public TextParserBase<IndexType> {
public:
explicit CSVParser(InputSplit *source,
const std::map<std::string, std::string>& args,
int nthread)
: TextParserBase<IndexType>(source, nthread) {
param_.Init(args);
CHECK_EQ(param_.format, "csv");
}
protected:
virtual void ParseBlock(char *begin,
char *end,
RowBlockContainer<IndexType> *out);
private:
CSVParserParam param_;
};
template <typename IndexType>
void CSVParser<IndexType>::
ParseBlock(char *begin,
char *end,
RowBlockContainer<IndexType> *out) {
out->Clear();
char * lbegin = begin;
char * lend = lbegin;
while (lbegin != end) {
// get line end
lend = lbegin + 1;
while (lend != end && *lend != '\n' && *lend != '\r') ++lend;
char* p = lbegin;
int column_index = 0;
IndexType idx = 0;
float label = 0.0f;
while (p != lend) {
char *endptr;
float v = strtof(p, &endptr);
p = endptr;
if (column_index == param_.label_column) {
label = v;
} else {
out->value.push_back(v);
out->index.push_back(idx++);
}
++column_index;
      while (p != lend && *p != ',') ++p;  // bounds check before dereference
if (p != lend) ++p;
}
// skip empty line
    while (lend != end && (*lend == '\n' || *lend == '\r')) ++lend;
lbegin = lend;
out->label.push_back(label);
out->offset.push_back(out->index.size());
}
CHECK(out->label.size() + 1 == out->offset.size());
}
} // namespace data
} // namespace dmlc
#endif // DMLC_DATA_CSV_PARSER_H_
//===== EXPANDED: ../dmlc-core/src/data/csv_parser.h =====
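// Input sketch (illustrative): with label_column=0, a line such as
// "1,0.5,0.25,3" parsed by CSVParser::ParseBlock above yields label 1 and a
// dense row with indices 0,1,2 and values 0.5,0.25,3; with the default
// label_column=-1 every column becomes data and the label stays 0.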
namespace dmlc {
/*! \brief namespace for useful input data structure */
namespace data {
template<typename IndexType>
Parser<IndexType> *
CreateLibSVMParser(const std::string& path,
const std::map<std::string, std::string>& args,
unsigned part_index,
unsigned num_parts) {
InputSplit* source = InputSplit::Create(
path.c_str(), part_index, num_parts, "text");
ParserImpl<IndexType> *parser = new LibSVMParser<IndexType>(source, 2);
#if DMLC_ENABLE_STD_THREAD
parser = new ThreadedParser<IndexType>(parser);
#endif
return parser;
}
template<typename IndexType>
Parser<IndexType> *
CreateCSVParser(const std::string& path,
const std::map<std::string, std::string>& args,
unsigned part_index,
unsigned num_parts) {
InputSplit* source = InputSplit::Create(
path.c_str(), part_index, num_parts, "text");
return new CSVParser<IndexType>(source, args, 2);
}
template<typename IndexType>
inline Parser<IndexType> *
CreateParser_(const char *uri_,
unsigned part_index,
unsigned num_parts,
const char *type) {
std::string ptype = type;
io::URISpec spec(uri_, part_index, num_parts);
if (ptype == "auto") {
if (spec.args.count("format") != 0) {
ptype = spec.args.at("format");
} else {
ptype = "libsvm";
}
}
const ParserFactoryReg<IndexType>* e =
Registry<ParserFactoryReg<IndexType> >::Get()->Find(ptype);
if (e == NULL) {
LOG(FATAL) << "Unknown data type " << ptype;
}
// create parser
return (*e->body)(spec.uri, spec.args, part_index, num_parts);
}
template<typename IndexType>
inline RowBlockIter<IndexType> *
CreateIter_(const char *uri_,
unsigned part_index,
unsigned num_parts,
const char *type) {
using namespace std;
io::URISpec spec(uri_, part_index, num_parts);
Parser<IndexType> *parser = CreateParser_<IndexType>
(spec.uri.c_str(), part_index, num_parts, type);
if (spec.cache_file.length() != 0) {
#if DMLC_ENABLE_STD_THREAD
return new DiskRowIter<IndexType>(parser, spec.cache_file.c_str(), true);
#else
LOG(FATAL) << "compile with c++0x or c++11 to enable cache file";
return NULL;
#endif
} else {
return new BasicRowIter<IndexType>(parser);
}
}
DMLC_REGISTER_PARAMETER(CSVParserParam);
} // namespace data
// template specialization
template<>
RowBlockIter<uint32_t> *
RowBlockIter<uint32_t>::Create(const char *uri,
unsigned part_index,
unsigned num_parts,
const char *type) {
return data::CreateIter_<uint32_t>(uri, part_index, num_parts, type);
}
template<>
RowBlockIter<uint64_t> *
RowBlockIter<uint64_t>::Create(const char *uri,
unsigned part_index,
unsigned num_parts,
const char *type) {
return data::CreateIter_<uint64_t>(uri, part_index, num_parts, type);
}
template<>
Parser<uint32_t> *
Parser<uint32_t>::Create(const char *uri_,
unsigned part_index,
unsigned num_parts,
const char *type) {
return data::CreateParser_<uint32_t>(uri_, part_index, num_parts, type);
}
template<>
Parser<uint64_t> *
Parser<uint64_t>::Create(const char *uri_,
unsigned part_index,
unsigned num_parts,
const char *type) {
return data::CreateParser_<uint64_t>(uri_, part_index, num_parts, type);
}
// registry
DMLC_REGISTRY_ENABLE(ParserFactoryReg<uint32_t>);
DMLC_REGISTRY_ENABLE(ParserFactoryReg<uint64_t>);
DMLC_REGISTER_DATA_PARSER(uint32_t, libsvm, data::CreateLibSVMParser<uint32_t>);
DMLC_REGISTER_DATA_PARSER(uint64_t, libsvm, data::CreateLibSVMParser<uint64_t>);
DMLC_REGISTER_DATA_PARSER(uint32_t, csv, data::CreateCSVParser<uint32_t>);
} // namespace dmlc
//===== EXPANDED: ../dmlc-core/src/data.cc =====
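// Usage sketch (illustrative, not compiled): the public entry points
// specialized above; "libsvm" and "csv" are the parser types registered here,
// while the file name is made up.
#if 0  // illustrative sketch only
inline void ExampleReadLibSVM() {
  dmlc::RowBlockIter<uint32_t> *iter =
      dmlc::RowBlockIter<uint32_t>::Create("train.libsvm", 0, 1, "libsvm");
  while (iter->Next()) {
    const dmlc::RowBlock<uint32_t> &batch = iter->Value();
    LOG(INFO) << "batch with " << batch.size << " rows, "
              << iter->NumCol() << " columns in total";
  }
  delete iter;
}
#endif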
//===== EXPANDING: ../dmlc-core/src/io.cc =====
// Copyright by Contributors
//===== EXPANDING: ../dmlc-core/src/io/single_file_split.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file single_file_split.h
 * \brief base implementation of line-splitter
* \author Tianqi Chen
*/
#ifndef DMLC_IO_SINGLE_FILE_SPLIT_H_
#define DMLC_IO_SINGLE_FILE_SPLIT_H_
#if defined(__FreeBSD__)
#define fopen64 std::fopen
#endif
namespace dmlc {
namespace io {
/*!
* \brief line split implementation from single FILE
* simply returns lines of files, used for stdin
*/
class SingleFileSplit : public InputSplit {
public:
explicit SingleFileSplit(const char *fname)
: use_stdin_(false), buffer_size_(kBufferSize),
chunk_begin_(NULL), chunk_end_(NULL) {
if (!std::strcmp(fname, "stdin")) {
#ifndef DMLC_STRICT_CXX98_
use_stdin_ = true; fp_ = stdin;
#endif
}
if (!use_stdin_) {
fp_ = fopen64(fname, "rb");
CHECK(fp_ != NULL) << "SingleFileSplit: fail to open " << fname;
}
buffer_.resize(kBufferSize);
}
virtual ~SingleFileSplit(void) {
if (!use_stdin_) std::fclose(fp_);
}
virtual void BeforeFirst(void) {
fseek(fp_, 0, SEEK_SET);
}
virtual void HintChunkSize(size_t chunk_size) {
buffer_size_ = std::max(chunk_size, buffer_size_);
}
virtual size_t GetTotalSize(void) {
struct stat buf;
fstat(fileno(fp_), &buf);
return buf.st_size;
}
virtual size_t Read(void *ptr, size_t size) {
return std::fread(ptr, 1, size, fp_);
}
virtual void ResetPartition(unsigned part_index, unsigned num_parts) {
CHECK(part_index == 0 && num_parts == 1);
this->BeforeFirst();
}
virtual void Write(const void *ptr, size_t size) {
LOG(FATAL) << "InputSplit do not support write";
}
virtual bool NextRecord(Blob *out_rec) {
if (chunk_begin_ == chunk_end_) {
if (!LoadChunk()) return false;
}
char *next = FindNextRecord(chunk_begin_,
chunk_end_);
out_rec->dptr = chunk_begin_;
out_rec->size = next - chunk_begin_;
chunk_begin_ = next;
return true;
}
virtual bool NextChunk(Blob *out_chunk) {
if (chunk_begin_ == chunk_end_) {
if (!LoadChunk()) return false;
}
out_chunk->dptr = chunk_begin_;
out_chunk->size = chunk_end_ - chunk_begin_;
chunk_begin_ = chunk_end_;
return true;
}
inline bool ReadChunk(void *buf, size_t *size) {
size_t max_size = *size;
if (max_size <= overflow_.length()) {
*size = 0; return true;
}
if (overflow_.length() != 0) {
std::memcpy(buf, BeginPtr(overflow_), overflow_.length());
}
size_t olen = overflow_.length();
overflow_.resize(0);
size_t nread = this->Read(reinterpret_cast<char*>(buf) + olen,
max_size - olen);
nread += olen;
if (nread == 0) return false;
if (nread != max_size) {
*size = nread;
return true;
} else {
const char *bptr = reinterpret_cast<const char*>(buf);
// return the last position where a record starts
const char *bend = this->FindLastRecordBegin(bptr, bptr + max_size);
*size = bend - bptr;
overflow_.resize(max_size - *size);
if (overflow_.length() != 0) {
std::memcpy(BeginPtr(overflow_), bend, overflow_.length());
}
return true;
}
}
protected:
inline const char* FindLastRecordBegin(const char *begin,
const char *end) {
if (begin == end) return begin;
for (const char *p = end - 1; p != begin; --p) {
if (*p == '\n' || *p == '\r') return p + 1;
}
return begin;
}
inline char* FindNextRecord(char *begin, char *end) {
char *p;
for (p = begin; p != end; ++p) {
if (*p == '\n' || *p == '\r') break;
}
for (; p != end; ++p) {
if (*p != '\n' && *p != '\r') return p;
}
return end;
}
inline bool LoadChunk(void) {
if (buffer_.length() < buffer_size_) {
buffer_.resize(buffer_size_);
}
while (true) {
size_t size = buffer_.length();
if (!ReadChunk(BeginPtr(buffer_), &size)) return false;
if (size == 0) {
buffer_.resize(buffer_.length() * 2);
} else {
chunk_begin_ = reinterpret_cast<char *>(BeginPtr(buffer_));
chunk_end_ = chunk_begin_ + size;
break;
}
}
return true;
}
private:
// buffer size
static const size_t kBufferSize = 1 << 18UL;
// file
std::FILE *fp_;
bool use_stdin_;
// internal overflow
std::string overflow_;
// internal buffer
std::string buffer_;
// internal buffer size
size_t buffer_size_;
// beginning of chunk
char *chunk_begin_;
// end of chunk
char *chunk_end_;
};
} // namespace io
} // namespace dmlc
#endif // DMLC_IO_SINGLE_FILE_SPLIT_H_
//===== EXPANDED: ../dmlc-core/src/io/single_file_split.h =====
//===== EXPANDING: ../dmlc-core/src/io/cached_input_split.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file cached_input_split.h
* \brief InputSplit that reads from an existing InputSplit
 * and caches the data to local disk; subsequent iterations
 * read from the local cached data
* \author Tianqi Chen
*/
#ifndef DMLC_IO_CACHED_INPUT_SPLIT_H_
#define DMLC_IO_CACHED_INPUT_SPLIT_H_
// this code depends on c++11
#if DMLC_ENABLE_STD_THREAD
namespace dmlc {
namespace io {
/*!
* \brief InputSplit that reads from an existing InputSplit
 * and caches the data to local disk; subsequent iterations
 * read from the local cached data
*/
class CachedInputSplit : public InputSplit {
public:
/*!
* \brief constructor
* \param base source input split
* \param cache_file the path to cache file
   * \param reuse_exist_cache whether to reuse an existing cache file, if any
*/
CachedInputSplit(InputSplitBase *base,
const char *cache_file,
bool reuse_exist_cache = true)
: buffer_size_(InputSplitBase::kBufferSize),
cache_file_(cache_file),
fo_(NULL), fi_(NULL),
base_(base), tmp_chunk_(NULL),
iter_preproc_(NULL) {
if (reuse_exist_cache) {
if (!this->InitCachedIter()) {
this->InitPreprocIter();
}
} else {
this->InitPreprocIter();
}
}
// destructor
virtual ~CachedInputSplit(void) {
// NOTE delete can handle NULL ptr
// deletion order matters
delete iter_preproc_;
delete fo_;
iter_cached_.Destroy();
delete tmp_chunk_;
delete base_;
delete fi_;
}
virtual void BeforeFirst(void) {
// if preprocessing did not end
// pull data from preprocessing module
if (iter_preproc_ != NULL) {
if (tmp_chunk_ != NULL) {
iter_preproc_->Recycle(&tmp_chunk_);
}
while (iter_preproc_->Next(&tmp_chunk_)) {
iter_preproc_->Recycle(&tmp_chunk_);
}
// finalize the push out process
delete iter_preproc_;
delete fo_;
iter_preproc_ = NULL;
fo_ = NULL;
CHECK(this->InitCachedIter())
<< "Failed to initialize CachedIter";
} else {
iter_cached_.BeforeFirst();
}
if (tmp_chunk_ != NULL) {
iter_cached_.Recycle(&tmp_chunk_);
}
}
virtual void ResetPartition(unsigned part_index, unsigned num_parts) {
LOG(FATAL) << "ResetPartition is not supported in CachedInputSplit";
}
virtual void HintChunkSize(size_t chunk_size) {
buffer_size_ = std::max(chunk_size / sizeof(size_t), buffer_size_);
}
virtual size_t GetTotalSize(void) {
return base_->GetTotalSize();
}
// implement next record
virtual bool NextRecord(Blob *out_rec) {
auto *iter = iter_preproc_ != NULL ? iter_preproc_ : &iter_cached_;
if (tmp_chunk_ == NULL) {
if (!iter->Next(&tmp_chunk_)) return false;
}
while (!base_->ExtractNextRecord(out_rec, tmp_chunk_)) {
iter->Recycle(&tmp_chunk_);
if (!iter->Next(&tmp_chunk_)) return false;
}
return true;
}
// implement next chunk
virtual bool NextChunk(Blob *out_chunk) {
auto *iter = iter_preproc_ != NULL ? iter_preproc_ : &iter_cached_;
if (tmp_chunk_ == NULL) {
if (!iter->Next(&tmp_chunk_)) return false;
}
while (!base_->ExtractNextChunk(out_chunk, tmp_chunk_)) {
iter->Recycle(&tmp_chunk_);
if (!iter->Next(&tmp_chunk_)) return false;
}
return true;
}
private:
/*! \brief internal buffer size */
size_t buffer_size_;
/*! \brief cache file path */
std::string cache_file_;
/*! \brief output stream to cache file*/
dmlc::Stream *fo_;
/*! \brief input stream from cache file */
dmlc::SeekStream *fi_;
/*! \brief the place where we get the data */
InputSplitBase *base_;
/*! \brief current chunk of data */
InputSplitBase::Chunk *tmp_chunk_;
/*! \brief backend thread iterator for preprocessing */
ThreadedIter<InputSplitBase::Chunk> *iter_preproc_;
/*! \brief backend thread iterator for cache */
ThreadedIter<InputSplitBase::Chunk> iter_cached_;
  /*! \brief initialize the preprocessing iterator */
inline void InitPreprocIter(void);
/*!
* \brief initialize the cached iterator
   * \return whether the file exists and
* initialization is successful
*/
inline bool InitCachedIter(void);
};
inline void CachedInputSplit::InitPreprocIter(void) {
fo_ = dmlc::Stream::Create(cache_file_.c_str(), "w");
iter_preproc_ = new ThreadedIter<InputSplitBase::Chunk>();
iter_preproc_->set_max_capacity(16);
iter_preproc_->Init([this](InputSplitBase::Chunk **dptr) {
if (*dptr == NULL) {
*dptr = new InputSplitBase::Chunk(buffer_size_);
}
auto *p = *dptr;
if (!p->Load(base_, buffer_size_)) return false;
// after loading, save to disk
size_t size = p->end - p->begin;
fo_->Write(&size, sizeof(size));
fo_->Write(p->begin, size);
return true;
});
}
inline bool CachedInputSplit::InitCachedIter(void) {
fi_ = dmlc::SeekStream::CreateForRead(cache_file_.c_str(), true);
if (fi_ == NULL) return false;
iter_cached_.Init([this](InputSplitBase::Chunk **dptr) {
if (*dptr == NULL) {
*dptr = new InputSplitBase::Chunk(buffer_size_);
}
auto *p = *dptr;
// read data from cache file
size_t size;
size_t nread = fi_->Read(&size, sizeof(size));
if (nread == 0) return false;
CHECK(nread == sizeof(size))
<< cache_file_ << " has invalid cache file format";
p->data.resize(size / sizeof(size_t) + 1);
p->begin = reinterpret_cast<char*>(BeginPtr(p->data));
p->end = p->begin + size;
CHECK(fi_->Read(p->begin, size) == size)
<< cache_file_ << " has invalid cache file format";
return true;
},
[this]() { fi_->Seek(0); });
return true;
}
} // namespace io
} // namespace dmlc
#endif  // DMLC_ENABLE_STD_THREAD
#endif // DMLC_IO_CACHED_INPUT_SPLIT_H_
//===== EXPANDED: ../dmlc-core/src/io/cached_input_split.h =====
#if DMLC_USE_HDFS
#endif
#if DMLC_USE_S3
#endif
#if DMLC_USE_AZURE
#endif
namespace dmlc {
namespace io {
FileSystem *FileSystem::GetInstance(const URI &path) {
if (path.protocol == "file://" || path.protocol.length() == 0) {
return LocalFileSystem::GetInstance();
}
if (path.protocol == "hdfs://") {
#if DMLC_USE_HDFS
return HDFSFileSystem::GetInstance(path.host);
#else
LOG(FATAL) << "Please compile with DMLC_USE_HDFS=1 to use hdfs";
#endif
}
if (path.protocol == "s3://" || path.protocol == "http://" || path.protocol == "https://") {
#if DMLC_USE_S3
return S3FileSystem::GetInstance();
#else
LOG(FATAL) << "Please compile with DMLC_USE_S3=1 to use S3";
#endif
}
if (path.protocol == "azure://") {
#if DMLC_USE_AZURE
return AzureFileSystem::GetInstance();
#else
LOG(FATAL) << "Please compile with DMLC_USE_AZURE=1 to use Azure";
#endif
}
LOG(FATAL) << "unknown filesystem protocol " + path.protocol;
return NULL;
}
} // namespace io
InputSplit* InputSplit::Create(const char *uri_,
unsigned part,
unsigned nsplit,
const char *type) {
using namespace std;
using namespace dmlc::io;
// allow cachefile in format path#cachefile
io::URISpec spec(uri_, part, nsplit);
if (!strcmp(spec.uri.c_str(), "stdin")) {
return new SingleFileSplit(spec.uri.c_str());
}
CHECK(part < nsplit) << "invalid input parameter for InputSplit::Create";
URI path(spec.uri.c_str());
InputSplitBase *split = NULL;
if (!strcmp(type, "text")) {
split = new LineSplitter(FileSystem::GetInstance(path),
spec.uri.c_str(), part, nsplit);
} else if (!strcmp(type, "recordio")) {
split = new RecordIOSplitter(FileSystem::GetInstance(path),
spec.uri.c_str(), part, nsplit);
} else {
LOG(FATAL) << "unknown input split type " << type;
}
#if DMLC_ENABLE_STD_THREAD
if (spec.cache_file.length() == 0) {
return split;
} else {
return new CachedInputSplit(split, spec.cache_file.c_str());
}
#else
CHECK(spec.cache_file.length() == 0)
<< "to enable cached file, compile with c++11";
return split;
#endif
}
Stream *Stream::Create(const char *uri,
const char * const flag,
bool try_create) {
io::URI path(uri);
return io::FileSystem::
GetInstance(path)->Open(path, flag, try_create);
}
SeekStream *SeekStream::CreateForRead(const char *uri, bool try_create) {
io::URI path(uri);
return io::FileSystem::
GetInstance(path)->OpenForRead(path, try_create);
}
} // namespace dmlc
//===== EXPANDED: ../dmlc-core/src/io.cc =====
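// Usage sketch (illustrative, not compiled): reading a sharded text input
// through the InputSplit factory defined above; each worker passes its own
// (part, nsplit) pair and records come back as Blob views into an internal
// chunk buffer. The file name is made up.
#if 0  // illustrative sketch only
inline void ExampleReadLines(unsigned part, unsigned nsplit) {
  dmlc::InputSplit *in =
      dmlc::InputSplit::Create("data.txt", part, nsplit, "text");
  dmlc::InputSplit::Blob rec;
  while (in->NextRecord(&rec)) {
    std::string line(static_cast<const char*>(rec.dptr), rec.size);
    // ... process one line ...
  }
  delete in;
}
#endif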
//===== EXPANDING: ../dmlc-core/src/recordio.cc =====
// Copyright by Contributors
namespace dmlc {
// implementation
void RecordIOWriter::WriteRecord(const void *buf, size_t size) {
CHECK(size < (1 << 29U))
<< "RecordIO only accept record less than 2^29 bytes";
const uint32_t umagic = kMagic;
// initialize the magic number, in stack
const char *magic = reinterpret_cast<const char*>(&umagic);
const char *bhead = reinterpret_cast<const char*>(buf);
uint32_t len = static_cast<uint32_t>(size);
uint32_t lower_align = (len >> 2U) << 2U;
uint32_t upper_align = ((len + 3U) >> 2U) << 2U;
uint32_t dptr = 0;
for (uint32_t i = 0; i < lower_align ; i += 4) {
// use char check for alignment safety reason
if (bhead[i] == magic[0] &&
bhead[i + 1] == magic[1] &&
bhead[i + 2] == magic[2] &&
bhead[i + 3] == magic[3]) {
uint32_t lrec = EncodeLRec(dptr == 0 ? 1U : 2U,
i - dptr);
stream_->Write(magic, 4);
stream_->Write(&lrec, sizeof(lrec));
if (i != dptr) {
stream_->Write(bhead + dptr, i - dptr);
}
dptr = i + 4;
except_counter_ += 1;
}
}
uint32_t lrec = EncodeLRec(dptr != 0 ? 3U : 0U,
len - dptr);
stream_->Write(magic, 4);
stream_->Write(&lrec, sizeof(lrec));
if (len != dptr) {
stream_->Write(bhead + dptr, len - dptr);
}
// write padded bytes
uint32_t zero = 0;
if (upper_align != len) {
stream_->Write(&zero, upper_align - len);
}
}
bool RecordIOReader::NextRecord(std::string *out_rec) {
if (end_of_stream_) return false;
const uint32_t kMagic = RecordIOWriter::kMagic;
out_rec->clear();
size_t size = 0;
while (true) {
uint32_t header[2];
size_t nread = stream_->Read(header, sizeof(header));
if (nread == 0) {
end_of_stream_ = true; return false;
}
CHECK(nread == sizeof(header)) << "Inavlid RecordIO File";
CHECK(header[0] == RecordIOWriter::kMagic) << "Invalid RecordIO File";
uint32_t cflag = RecordIOWriter::DecodeFlag(header[1]);
uint32_t len = RecordIOWriter::DecodeLength(header[1]);
uint32_t upper_align = ((len + 3U) >> 2U) << 2U;
out_rec->resize(size + upper_align);
if (upper_align != 0) {
CHECK(stream_->Read(BeginPtr(*out_rec) + size, upper_align) == upper_align)
<< "Invalid RecordIO File upper_align=" << upper_align;
}
// squeeze back
size += len; out_rec->resize(size);
if (cflag == 0U || cflag == 3U) break;
out_rec->resize(size + sizeof(kMagic));
std::memcpy(BeginPtr(*out_rec) + size, &kMagic, sizeof(kMagic));
size += sizeof(kMagic);
}
return true;
}
// helper function to find next recordio head
inline char *FindNextRecordIOHead(char *begin, char *end) {
CHECK_EQ((reinterpret_cast<size_t>(begin) & 3UL), 0U);
CHECK_EQ((reinterpret_cast<size_t>(end) & 3UL), 0U);
uint32_t *p = reinterpret_cast<uint32_t *>(begin);
uint32_t *pend = reinterpret_cast<uint32_t *>(end);
for (; p + 1 < pend; ++p) {
if (p[0] == RecordIOWriter::kMagic) {
uint32_t cflag = RecordIOWriter::DecodeFlag(p[1]);
if (cflag == 0 || cflag == 1) {
return reinterpret_cast<char*>(p);
}
}
}
return end;
}
RecordIOChunkReader::RecordIOChunkReader(InputSplit::Blob chunk,
unsigned part_index,
unsigned num_parts) {
size_t nstep = (chunk.size + num_parts - 1) / num_parts;
// align
nstep = ((nstep + 3UL) >> 2UL) << 2UL;
size_t begin = std::min(chunk.size, nstep * part_index);
size_t end = std::min(chunk.size, nstep * (part_index + 1));
char *head = reinterpret_cast<char*>(chunk.dptr);
pbegin_ = FindNextRecordIOHead(head + begin, head + chunk.size);
pend_ = FindNextRecordIOHead(head + end, head + chunk.size);
}
bool RecordIOChunkReader::NextRecord(InputSplit::Blob *out_rec) {
if (pbegin_ >= pend_) return false;
uint32_t *p = reinterpret_cast<uint32_t *>(pbegin_);
CHECK(p[0] == RecordIOWriter::kMagic);
uint32_t cflag = RecordIOWriter::DecodeFlag(p[1]);
uint32_t clen = RecordIOWriter::DecodeLength(p[1]);
if (cflag == 0) {
// skip header
out_rec->dptr = pbegin_ + 2 * sizeof(uint32_t);
// move pbegin
pbegin_ += 2 * sizeof(uint32_t) + (((clen + 3U) >> 2U) << 2U);
CHECK(pbegin_ <= pend_) << "Invalid RecordIO Format";
out_rec->size = clen;
return true;
} else {
const uint32_t kMagic = RecordIOWriter::kMagic;
// abnormal path, read into string
CHECK(cflag == 1U) << "Invalid RecordIO Format";
temp_.resize(0);
while (true) {
CHECK(pbegin_ + 2 * sizeof(uint32_t) <= pend_);
p = reinterpret_cast<uint32_t *>(pbegin_);
CHECK(p[0] == RecordIOWriter::kMagic);
cflag = RecordIOWriter::DecodeFlag(p[1]);
clen = RecordIOWriter::DecodeLength(p[1]);
size_t tsize = temp_.length();
temp_.resize(tsize + clen);
if (clen != 0) {
std::memcpy(BeginPtr(temp_) + tsize,
pbegin_ + 2 * sizeof(uint32_t),
clen);
tsize += clen;
}
pbegin_ += 2 * sizeof(uint32_t) + (((clen + 3U) >> 2U) << 2U);
if (cflag == 3U) break;
temp_.resize(tsize + sizeof(kMagic));
std::memcpy(BeginPtr(temp_) + tsize, &kMagic, sizeof(kMagic));
}
out_rec->dptr = BeginPtr(temp_);
out_rec->size = temp_.length();
return true;
}
}
} // namespace dmlc
//===== EXPANDED: ../dmlc-core/src/recordio.cc =====
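// Usage sketch (illustrative, not compiled): round-tripping records through
// the writer/reader implemented above. The Stream*-taking constructors are
// declared in dmlc/recordio.h, which is not shown in this expansion, so they
// are assumed here; the file name and payload are made up.
#if 0  // illustrative sketch only
inline void ExampleRecordIO() {
  dmlc::Stream *fo = dmlc::Stream::Create("data.rec", "w");
  dmlc::RecordIOWriter writer(fo);
  std::string payload = "hello recordio";
  writer.WriteRecord(payload.data(), payload.length());
  delete fo;

  dmlc::Stream *fi = dmlc::Stream::Create("data.rec", "r");
  dmlc::RecordIOReader reader(fi);
  std::string rec;
  while (reader.NextRecord(&rec)) {
    // each rec is one payload, with any embedded magic words restored
  }
  delete fi;
}
#endif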
//===== EXPANDED: dmlc-minimum0.cc =====
//===== EXPANDING: nnvm.cc =====
//===== EXPANDING: ../mshadow/mshadow/tensor.h =====
/*!
* Copyright (c) 2014 by Contributors
* \file tensor.h
* \brief header file of tensor data structure and functions
* This lib requires explicit memory allocation and de-allocation
* all the data structure Tensor<cpu,1>, Tensor<gpu,1> are like handles(pointers),
* no memory allocation is happening during calculation
*
* For STL style tensor, see tensor_container.h
* \author Bing Xu, Tianqi Chen
*/
#ifndef MSHADOW_TENSOR_H_
#define MSHADOW_TENSOR_H_
//===== EXPANDING: ../mshadow/mshadow/base.h =====
/*!
* Copyright (c) 2014 by Contributors
* \file base.h
* \brief definitions of base types, operators, macros functions
*
* \author Bing Xu, Tianqi Chen
*/
#ifndef MSHADOW_BASE_H_
#define MSHADOW_BASE_H_
#ifdef _MSC_VER
#ifndef _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_WARNINGS
#endif
#ifndef _CRT_SECURE_NO_DEPRECATE
#define _CRT_SECURE_NO_DEPRECATE
#endif
#define NOMINMAX
#endif
#ifdef _MSC_VER
//! \cond Doxygen_Suppress
typedef signed char int8_t;
typedef __int16 int16_t;
typedef __int32 int32_t;
typedef __int64 int64_t;
typedef unsigned char uint8_t;
typedef unsigned __int16 uint16_t;
typedef unsigned __int32 uint32_t;
typedef unsigned __int64 uint64_t;
//! \endcond
#else
#endif
// macro defintiions
/*!
* \brief if this macro is define to be 1,
* mshadow should compile without any of other libs
*/
#ifndef MSHADOW_STAND_ALONE
#define MSHADOW_STAND_ALONE 0
#endif
/*! \brief whether do padding during allocation */
#ifndef MSHADOW_ALLOC_PAD
#define MSHADOW_ALLOC_PAD true
#endif
/*!
* \brief
 * the x dimension of the data must be bigger than pad_size * ratio to be allocated padded memory,
 * otherwise tight allocation is used;
 * for example, if pad_ratio=2 and the GPU memory alignment size is 32,
 * then we will only allocate padded memory if the x dimension > 64;
 * set it to 0 to always allocate padded memory
*/
#ifndef MSHADOW_MIN_PAD_RATIO
#define MSHADOW_MIN_PAD_RATIO 2
#endif
#if MSHADOW_STAND_ALONE
#define MSHADOW_USE_CBLAS 0
#define MSHADOW_USE_MKL 0
#define MSHADOW_USE_CUDA 0
#endif
/*!
* \brief force user to use GPU stream during computation
* error will be shot when default stream NULL is used
*/
#ifndef MSHADOW_FORCE_STREAM
#define MSHADOW_FORCE_STREAM 1
#endif
/*! \brief use CBLAS for BLAS */
#ifndef MSHADOW_USE_CBLAS
#define MSHADOW_USE_CBLAS 0
#endif
/*! \brief use MKL for BLAS */
#ifndef MSHADOW_USE_MKL
#define MSHADOW_USE_MKL 1
#endif
/*!
* \brief use CUDA support, must ensure that the cuda include path is correct,
* or directly compile using nvcc
*/
#ifndef MSHADOW_USE_CUDA
#define MSHADOW_USE_CUDA 1
#endif
/*!
* \brief use CUDNN support, must ensure that the cudnn include path is correct
*/
#ifndef MSHADOW_USE_CUDNN
#define MSHADOW_USE_CUDNN 0
#endif
/*!
 * \brief it seems CUDAARCH is deprecated in future NVCC;
 *  set this to 1 if you want to use a CUDA version older than 2.0
*/
#ifndef MSHADOW_OLD_CUDA
#define MSHADOW_OLD_CUDA 0
#endif
/*!
* \brief macro to decide existence of c++11 compiler
*/
#ifndef MSHADOW_IN_CXX11
#define MSHADOW_IN_CXX11 (defined(__GXX_EXPERIMENTAL_CXX0X__) ||\
__cplusplus >= 201103L || defined(_MSC_VER))
#endif
/*! \brief whether use SSE */
#ifndef MSHADOW_USE_SSE
#define MSHADOW_USE_SSE 1
#endif
/*! \brief whether use NVML to get dynamic info */
#ifndef MSHADOW_USE_NVML
#define MSHADOW_USE_NVML 0
#endif
// SSE is conflict with cudacc
#ifdef __CUDACC__
#undef MSHADOW_USE_SSE
#define MSHADOW_USE_SSE 0
#endif
#if MSHADOW_USE_CBLAS
extern "C" {
}
#elif MSHADOW_USE_MKL
#endif
#if MSHADOW_USE_CUDA
#endif
#if MSHADOW_USE_CUDNN == 1
#endif
#if MSHADOW_USE_NVML
#endif
// --------------------------------
// MSHADOW_XINLINE is used for inlining template code for both CUDA and CPU code
#ifdef MSHADOW_XINLINE
#error "MSHADOW_XINLINE must not be defined"
#endif
#ifdef _MSC_VER
#define MSHADOW_FORCE_INLINE __forceinline
#pragma warning(disable : 4068)
#else
#define MSHADOW_FORCE_INLINE inline __attribute__((always_inline))
#endif
#ifdef __CUDACC__
#define MSHADOW_XINLINE MSHADOW_FORCE_INLINE __device__ __host__
#else
#define MSHADOW_XINLINE MSHADOW_FORCE_INLINE
#endif
/*! \brief cpu force inline */
#define MSHADOW_CINLINE MSHADOW_FORCE_INLINE
#if defined(__GXX_EXPERIMENTAL_CXX0X) ||\
defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L
#define MSHADOW_CONSTEXPR constexpr
#else
#define MSHADOW_CONSTEXPR const
#endif
/*!
* \brief default data type for tensor string
* in code release, change it to default_real_t
* during development, change it to empty string so that missing
* template arguments can be detected
*/
#ifndef MSHADOW_DEFAULT_DTYPE
#define MSHADOW_DEFAULT_DTYPE = default_real_t
#endif
/*!
 * \brief DMLC macro for logging
*/
#ifndef MSHADOW_USE_GLOG
#define MSHADOW_USE_GLOG DMLC_USE_GLOG
#endif // MSHADOW_USE_GLOG
#if DMLC_USE_CXX11
#define MSHADOW_THROW_EXCEPTION noexcept(false)
#define MSHADOW_NO_EXCEPTION noexcept(true)
#else
#define MSHADOW_THROW_EXCEPTION
#define MSHADOW_NO_EXCEPTION
#endif
/*!
* \brief Protected cuda call in mshadow
* \param func Expression to call.
* It checks for CUDA errors after invocation of the expression.
*/
#define MSHADOW_CUDA_CALL(func) \
{ \
cudaError_t e = (func); \
if (e == cudaErrorCudartUnloading) { \
throw dmlc::Error(cudaGetErrorString(e)); \
} \
CHECK(e == cudaSuccess) \
<< "CUDA: " << cudaGetErrorString(e); \
}
/*!
* \brief Run function and catch error, log unknown error.
* \param func Expression to call.
*/
#define MSHADOW_CATCH_ERROR(func) \
{ \
try { \
(func); \
} catch (const dmlc::Error &e) { \
std::string what = e.what(); \
if (what.find("driver shutting down") == std::string::npos) { \
LOG(ERROR) << "Ignore CUDA Error " << what; \
} \
} \
}
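// Usage sketch (illustrative, not compiled): how the two macros above are used
// in practice. MSHADOW_CUDA_CALL turns a failed CUDA runtime call into a fatal
// check (or a dmlc::Error during cudart unloading), while MSHADOW_CATCH_ERROR
// swallows the benign "driver shutting down" error seen during teardown.
#if 0  // illustrative sketch only, requires CUDA
inline void ExampleCudaCalls() {
  void *dptr = NULL;
  MSHADOW_CUDA_CALL(cudaMalloc(&dptr, 1024));
  MSHADOW_CATCH_ERROR(MSHADOW_CUDA_CALL(cudaFree(dptr)));
}
#endif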
//===== EXPANDING: ../mshadow/mshadow/half.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file half.h
* \brief definition of half (float16) type.
*
* \author Junyuan Xie
*/
#ifndef MSHADOW_HALF_H_
#define MSHADOW_HALF_H_
#if (MSHADOW_USE_CUDA && CUDA_VERSION >= 7050)
#define MSHADOW_CUDA_HALF 1
#if defined(__CUDA_ARCH__)
/*! \brief __half2float_warp */
__host__ __device__ float __half2float_warp(const volatile __half& h) { /* NOLINT(*) */
__half val;
val.x = h.x;
return __half2float(val);
}
#endif
#else
#define MSHADOW_CUDA_HALF 0
#endif
/*! \brief namespace for mshadow */
namespace mshadow {
/*! \brief namespace for host/device portable half-precision floats */
namespace half {
#define MSHADOW_HALF_OPERATOR(RTYPE, OP) \
MSHADOW_XINLINE RTYPE operator OP (half_t a, half_t b) { \
return RTYPE(float(a) OP float(b)); /* NOLINT(*) */ \
} \
template<typename T> \
MSHADOW_XINLINE RTYPE operator OP (half_t a, T b) { \
return RTYPE(float(a) OP float(b)); /* NOLINT(*) */ \
} \
template<typename T> \
MSHADOW_XINLINE RTYPE operator OP (T a, half_t b) { \
return RTYPE(float(a) OP float(b)); /* NOLINT(*) */ \
}
#define MSHADOW_HALF_ASSIGNOP(AOP, OP) \
template<typename T> \
MSHADOW_XINLINE half_t operator AOP (const T& a) { \
return *this = half_t(float(*this) OP float(a)); /* NOLINT(*)*/ \
} \
template<typename T> \
MSHADOW_XINLINE half_t operator AOP (const volatile T& a) volatile { \
return *this = half_t(float(*this) OP float(a)); /* NOLINT(*)*/ \
}
#if (MSHADOW_CUDA_HALF && defined(__CUDA_ARCH__))
#define MSHADOW_HALF_CONVERSIONOP(T) \
MSHADOW_XINLINE operator T() const { \
return T(__half2float(cuhalf_)); /* NOLINT(*)*/ \
} \
MSHADOW_XINLINE operator T() const volatile { \
return T(__half2float_warp(cuhalf_)); /* NOLINT(*)*/ \
}
#else
#define MSHADOW_HALF_CONVERSIONOP(T) \
MSHADOW_XINLINE operator T() const { \
return T(half2float(half_)); /* NOLINT(*)*/ \
} \
MSHADOW_XINLINE operator T() const volatile { \
return T(half2float(half_)); /* NOLINT(*)*/ \
}
#endif // (MSHADOW_CUDA_HALF && defined(__CUDA_ARCH__))
class half_t {
public:
union {
uint16_t half_;
#if MSHADOW_CUDA_HALF
__half cuhalf_;
#endif // MSHADOW_CUDA_HALF
};
static MSHADOW_XINLINE half_t Binary(uint16_t value) {
half_t res;
res.half_ = value;
return res;
}
MSHADOW_XINLINE half_t() {}
#if MSHADOW_CUDA_HALF
MSHADOW_XINLINE explicit half_t(const __half& value) {
cuhalf_ = value;
}
#endif // MSHADOW_CUDA_HALF
MSHADOW_XINLINE half_t(const float& value) { constructor(value); }
MSHADOW_XINLINE explicit half_t(const double& value) { constructor(value); }
MSHADOW_XINLINE explicit half_t(const uint8_t& value) { constructor(value); }
MSHADOW_XINLINE explicit half_t(const int32_t& value) { constructor(value); }
MSHADOW_XINLINE explicit half_t(const uint32_t& value) { constructor(value); }
MSHADOW_XINLINE explicit half_t(const int64_t& value) { constructor(value); }
MSHADOW_XINLINE explicit half_t(const uint64_t& value) { constructor(value); }
MSHADOW_HALF_CONVERSIONOP(float)
MSHADOW_HALF_ASSIGNOP(+=, +)
MSHADOW_HALF_ASSIGNOP(-=, -)
MSHADOW_HALF_ASSIGNOP(*=, *)
MSHADOW_HALF_ASSIGNOP(/=, /)
MSHADOW_XINLINE half_t operator+() {
return *this;
}
MSHADOW_XINLINE half_t operator-() {
return half_t(-float(*this)); // NOLINT(*)
}
MSHADOW_XINLINE half_t operator=(const half_t& a) {
half_ = a.half_;
return a;
}
template<typename T>
MSHADOW_XINLINE half_t operator=(const T& a) {
return *this = half_t(a); /* NOLINT(*)*/
}
MSHADOW_XINLINE half_t operator=(const half_t& a) volatile {
half_ = a.half_;
return a;
}
template<typename T>
MSHADOW_XINLINE half_t operator=(const T& a) volatile {
return *this = half_t(a); /* NOLINT(*)*/
}
private:
union Bits {
float f;
int32_t si;
uint32_t ui;
};
static int const shift = 13;
static int const shiftSign = 16;
static int32_t const infN = 0x7F800000; // flt32 infinity
static int32_t const maxN = 0x477FE000; // max flt16 normal as a flt32
static int32_t const minN = 0x38800000; // min flt16 normal as a flt32
static int32_t const signN = 0x80000000; // flt32 sign bit
static int32_t const infC = infN >> shift;
static int32_t const nanN = (infC + 1) << shift; // minimum flt16 nan as a flt32
static int32_t const maxC = maxN >> shift;
static int32_t const minC = minN >> shift;
static int32_t const signC = signN >> shiftSign; // flt16 sign bit
static int32_t const mulN = 0x52000000; // (1 << 23) / minN
static int32_t const mulC = 0x33800000; // minN / (1 << (23 - shift))
static int32_t const subC = 0x003FF; // max flt32 subnormal down shifted
static int32_t const norC = 0x00400; // min flt32 normal down shifted
static int32_t const maxD = infC - maxC - 1;
static int32_t const minD = minC - subC - 1;
MSHADOW_XINLINE uint16_t float2half(const float& value) const {
Bits v, s;
v.f = value;
uint32_t sign = v.si & signN;
v.si ^= sign;
sign >>= shiftSign; // logical shift
s.si = mulN;
s.si = s.f * v.f; // correct subnormals
v.si ^= (s.si ^ v.si) & -(minN > v.si);
v.si ^= (infN ^ v.si) & -((infN > v.si) & (v.si > maxN));
v.si ^= (nanN ^ v.si) & -((nanN > v.si) & (v.si > infN));
v.ui >>= shift; // logical shift
v.si ^= ((v.si - maxD) ^ v.si) & -(v.si > maxC);
v.si ^= ((v.si - minD) ^ v.si) & -(v.si > subC);
return v.ui | sign;
}
MSHADOW_XINLINE uint16_t float2half(const volatile float& value) const volatile { // NOLINT (*)
Bits v, s;
v.f = value;
uint32_t sign = v.si & signN;
v.si ^= sign;
sign >>= shiftSign; // logical shift
s.si = mulN;
s.si = s.f * v.f; // correct subnormals
v.si ^= (s.si ^ v.si) & -(minN > v.si);
v.si ^= (infN ^ v.si) & -((infN > v.si) & (v.si > maxN));
v.si ^= (nanN ^ v.si) & -((nanN > v.si) & (v.si > infN));
v.ui >>= shift; // logical shift
v.si ^= ((v.si - maxD) ^ v.si) & -(v.si > maxC);
v.si ^= ((v.si - minD) ^ v.si) & -(v.si > subC);
return v.ui | sign;
}
MSHADOW_XINLINE float half2float(const uint16_t& value) const {
Bits v;
v.ui = value;
int32_t sign = v.si & signC;
v.si ^= sign;
sign <<= shiftSign;
v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
Bits s;
s.si = mulC;
s.f *= v.si;
int32_t mask = -(norC > v.si);
v.si <<= shift;
v.si ^= (s.si ^ v.si) & mask;
v.si |= sign;
return v.f;
}
MSHADOW_XINLINE float half2float(const volatile uint16_t& value) const volatile { // NOLINT(*)
Bits v;
v.ui = value;
int32_t sign = v.si & signC;
v.si ^= sign;
sign <<= shiftSign;
v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
Bits s;
s.si = mulC;
s.f *= v.si;
int32_t mask = -(norC > v.si);
v.si <<= shift;
v.si ^= (s.si ^ v.si) & mask;
v.si |= sign;
return v.f;
}
template<typename T>
MSHADOW_XINLINE void constructor(const T& value) {
#if (MSHADOW_CUDA_HALF && defined(__CUDA_ARCH__))
cuhalf_ = __float2half(float(value)); // NOLINT(*)
#else
half_ = float2half(float(value)); // NOLINT(*)
#endif // (MSHADOW_CUDA_HALF && defined(__CUDA_ARCH__))
}
};
/*! \brief overloaded + operator for half_t */
MSHADOW_HALF_OPERATOR(half_t, +)
/*! \brief overloaded - operator for half_t */
MSHADOW_HALF_OPERATOR(half_t, -)
/*! \brief overloaded * operator for half_t */
MSHADOW_HALF_OPERATOR(half_t, *)
/*! \brief overloaded / operator for half_t */
MSHADOW_HALF_OPERATOR(half_t, /)
/*! \brief overloaded > operator for half_t */
MSHADOW_HALF_OPERATOR(bool, >)
/*! \brief overloaded < operator for half_t */
MSHADOW_HALF_OPERATOR(bool, <)
/*! \brief overloaded >= operator for half_t */
MSHADOW_HALF_OPERATOR(bool, >=)
/*! \brief overloaded <= operator for half_t */
MSHADOW_HALF_OPERATOR(bool, <=)
#define MSHADOW_HALF_MIN mshadow::half::half_t::Binary(0x0400);
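/*!
 * \brief Example: a minimal usage sketch of half_t (hypothetical values,
 *        assuming the converting constructors and the float conversion
 *        operator defined earlier in this header):
 * \code
 *   half_t a = 1.5f;        // converted via float2half on CPU
 *   half_t b(2.0);          // double is also accepted
 *   half_t c = a * b + a;   // arithmetic via the overloaded operators above
 *   float  f = float(c);    // back to float
 * \endcode
 */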
} // namespace half
} // namespace mshadow
#endif // MSHADOW_HALF_H_
//===== EXPANDED: ../mshadow/mshadow/half.h =====
//===== EXPANDING: ../mshadow/mshadow/logging.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file logging.h
* \brief defines logging macros of dmlc
* allows use of GLOG, fall back to internal
* implementation when disabled
*/
#ifndef MSHADOW_LOGGING_H_
#define MSHADOW_LOGGING_H_
#ifndef DMLC_LOGGING_H_
#define DMLC_LOGGING_H_
namespace dmlc {
/*! \brief taken from DMLC directly */
/*!
* \brief exception class that will be thrown by
* default logger if DMLC_LOG_FATAL_THROW == 1
*/
struct Error : public std::runtime_error {
/*!
* \brief constructor
* \param s the error message
*/
explicit Error(const std::string &s) : std::runtime_error(s) {}
};
} // namespace dmlc
#if defined(_MSC_VER) && _MSC_VER < 1900
#define noexcept(a)
#endif
#if DMLC_USE_GLOG
namespace dmlc {
/*! \brief taken from DMLC directly */
inline void InitLogging(const char* argv0) {
google::InitGoogleLogging(argv0);
}
} // namespace dmlc
#else
// use a light version of glog
#if defined(_MSC_VER)
#pragma warning(disable : 4722)
#endif
namespace dmlc {
inline void InitLogging(const char* argv0) {
// DO NOTHING
}
// Always-on checking
#define CHECK(x) \
if (!(x)) \
dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "Check " \
"failed: " #x << ' '
#define CHECK_LT(x, y) CHECK((x) < (y))
#define CHECK_GT(x, y) CHECK((x) > (y))
#define CHECK_LE(x, y) CHECK((x) <= (y))
#define CHECK_GE(x, y) CHECK((x) >= (y))
#define CHECK_EQ(x, y) CHECK((x) == (y))
#define CHECK_NE(x, y) CHECK((x) != (y))
#define CHECK_NOTNULL(x) \
((x) == NULL ? dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "Check notnull: " #x << ' ', (x) : (x)) // NOLINT(*)
// Debug-only checking.
#ifdef NDEBUG
#define DCHECK(x) \
while (false) CHECK(x)
#define DCHECK_LT(x, y) \
while (false) CHECK((x) < (y))
#define DCHECK_GT(x, y) \
while (false) CHECK((x) > (y))
#define DCHECK_LE(x, y) \
while (false) CHECK((x) <= (y))
#define DCHECK_GE(x, y) \
while (false) CHECK((x) >= (y))
#define DCHECK_EQ(x, y) \
while (false) CHECK((x) == (y))
#define DCHECK_NE(x, y) \
while (false) CHECK((x) != (y))
#else
#define DCHECK(x) CHECK(x)
#define DCHECK_LT(x, y) CHECK((x) < (y))
#define DCHECK_GT(x, y) CHECK((x) > (y))
#define DCHECK_LE(x, y) CHECK((x) <= (y))
#define DCHECK_GE(x, y) CHECK((x) >= (y))
#define DCHECK_EQ(x, y) CHECK((x) == (y))
#define DCHECK_NE(x, y) CHECK((x) != (y))
#endif // NDEBUG
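/*!
 * \brief Example: a minimal sketch of the checking macros above
 *        (hypothetical values; a failed check throws dmlc::Error when
 *        DMLC_LOG_FATAL_THROW is enabled):
 * \code
 *   int batch = 4;
 *   CHECK_GT(batch, 0) << "batch must be positive";
 *   CHECK_EQ(batch % 2, 0) << "batch must be even, got " << batch;
 *   DCHECK(batch < 1024);  // disabled in NDEBUG builds
 * \endcode
 */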
#define LOG_INFO dmlc::LogMessage(__FILE__, __LINE__)
#define LOG_ERROR LOG_INFO
#define LOG_WARNING LOG_INFO
#define LOG_FATAL dmlc::LogMessageFatal(__FILE__, __LINE__)
#define LOG_QFATAL LOG_FATAL
// Poor man's version of VLOG
#define VLOG(x) LOG_INFO.stream()
#define LOG(severity) LOG_##severity.stream()
#define LG LOG_INFO.stream()
#define LOG_IF(severity, condition) \
!(condition) ? (void)0 : dmlc::LogMessageVoidify() & LOG(severity)
#ifdef NDEBUG
#define LOG_DFATAL LOG_ERROR
#define DFATAL ERROR
#define DLOG(severity) true ? (void)0 : dmlc::LogMessageVoidify() & LOG(severity)
#define DLOG_IF(severity, condition) \
(true || !(condition)) ? (void)0 : dmlc::LogMessageVoidify() & LOG(severity)
#else
#define LOG_DFATAL LOG_FATAL
#define DFATAL FATAL
#define DLOG(severity) LOG(severity)
#define DLOG_IF(severity, condition) LOG_IF(severity, condition)
#endif
// Poor man's version of LOG_EVERY_N
#define LOG_EVERY_N(severity, n) LOG(severity)
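/*!
 * \brief Example: a minimal logging sketch using the macros above
 *        (messages go to std::cerr, or std::cout on Android):
 * \code
 *   LOG(INFO) << "loading " << 3 << " files";
 *   LOG_IF(ERROR, 2 + 2 != 4) << "never printed";
 *   // LOG(FATAL) << "unrecoverable";  // would throw dmlc::Error (or abort)
 * \endcode
 */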
class DateLogger {
public:
DateLogger() {
#if defined(_MSC_VER)
_tzset();
#endif
}
const char* HumanDate() {
#if defined(_MSC_VER)
_strtime_s(buffer_, sizeof(buffer_));
#else
time_t time_value = time(NULL);
struct tm now;
localtime_r(&time_value, &now);
snprintf(buffer_, sizeof(buffer_), "%02d:%02d:%02d", now.tm_hour,
now.tm_min, now.tm_sec);
#endif
return buffer_;
}
private:
char buffer_[9];
};
class LogMessage {
public:
LogMessage(const char* file, int line)
:
#ifdef __ANDROID__
log_stream_(std::cout)
#else
log_stream_(std::cerr)
#endif
{
log_stream_ << "[" << pretty_date_.HumanDate() << "] " << file << ":"
<< line << ": ";
}
~LogMessage() { log_stream_ << "\n"; }
std::ostream& stream() { return log_stream_; }
protected:
std::ostream& log_stream_;
private:
DateLogger pretty_date_;
LogMessage(const LogMessage&);
void operator=(const LogMessage&);
};
#if DMLC_LOG_FATAL_THROW == 0
class LogMessageFatal : public LogMessage {
public:
LogMessageFatal(const char* file, int line) : LogMessage(file, line) {}
~LogMessageFatal() {
log_stream_ << "\n";
abort();
}
private:
LogMessageFatal(const LogMessageFatal&);
void operator=(const LogMessageFatal&);
};
#else
class LogMessageFatal {
public:
LogMessageFatal(const char* file, int line) {
log_stream_ << "[" << pretty_date_.HumanDate() << "] " << file << ":"
<< line << ": ";
}
std::ostringstream &stream() { return log_stream_; }
~LogMessageFatal() DMLC_THROW_EXCEPTION {
// throwing out of destructor is evil
// hopefully we can do it here
throw Error(log_stream_.str());
}
private:
std::ostringstream log_stream_;
DateLogger pretty_date_;
LogMessageFatal(const LogMessageFatal&);
void operator=(const LogMessageFatal&);
};
#endif
// This class is used to explicitly ignore values in the conditional
// logging macros. This avoids compiler warnings like "value computed
// is not used" and "statement has no effect".
class LogMessageVoidify {
public:
LogMessageVoidify() {}
// This has to be an operator with a precedence lower than << but
// higher than "?:". See its usage.
void operator&(std::ostream&) {}
};
} // namespace dmlc
#endif
#endif // DMLC_LOGGING_H_
#endif // MSHADOW_LOGGING_H_
//===== EXPANDED: ../mshadow/mshadow/logging.h =====
/*! \brief namespace for mshadow */
namespace mshadow {
/*! \brief buffer size for each random number generator */
const unsigned kRandBufferSize = 1000000;
/*! \brief pi */
const float kPi = 3.1415926f;
/*! \brief type that will be used for index */
typedef unsigned index_t;
#ifdef _WIN32
/*! \brief openmp index for windows */
typedef int64_t openmp_index_t;
#else
/*! \brief openmp index for linux */
typedef index_t openmp_index_t;
#endif
/*! \brief float point type that will be used in default by mshadow */
typedef float default_real_t;
/*! \brief data type flag */
enum TypeFlag {
kFloat32,
kFloat64,
kFloat16,
kUint8,
kInt32
};
template<typename DType>
struct DataType;
template<>
struct DataType<float> {
static const int kFlag = kFloat32;
#if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1)
static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_FLOAT;
typedef float ScaleType;
#endif
};
template<>
struct DataType<double> {
static const int kFlag = kFloat64;
#if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1)
static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_DOUBLE;
typedef double ScaleType;
#endif
};
template<>
struct DataType<half::half_t> {
static const int kFlag = kFloat16;
#if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1)
static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_HALF;
typedef float ScaleType;
#endif
};
template<>
struct DataType<uint8_t> {
static const int kFlag = kUint8;
};
template<>
struct DataType<int32_t> {
static const int kFlag = kInt32;
};
/*! \brief type enum value for default real type */
const int default_type_flag = DataType<default_real_t>::kFlag;
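/*!
 * \brief Example: mapping between C++ types and type flags
 *        (a small sketch using the specializations above):
 * \code
 *   static_assert(mshadow::DataType<float>::kFlag == mshadow::kFloat32, "");
 *   int flag = mshadow::DataType<mshadow::half::half_t>::kFlag;  // kFloat16
 * \endcode
 */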
/*! layout flag */
enum LayoutFlag {
kNCHW = 0,
kNHWC,
kCHWN,
kNCDHW = 1 << 5,
kNDHWC,
kCDHWN
};
template<int layout>
struct LayoutType;
template<>
struct LayoutType<kNCHW> {
static const index_t kNdim = 4;
#if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4)
static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NCHW;
#else
static const int kCudnnFlag = -1;
#endif
};
template<>
struct LayoutType<kNHWC> {
static const index_t kNdim = 4;
#if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4)
static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NHWC;
#else
static const int kCudnnFlag = -1;
#endif
};
/*! \brief default layout for 4d tensor */
const int default_layout = kNCHW;
template<>
struct LayoutType<kNCDHW> {
static const index_t kNdim = 5;
#if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4)
static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NCHW;
#else
static const int kCudnnFlag = -1;
#endif
};
template<>
struct LayoutType<kNDHWC> {
static const index_t kNdim = 5;
#if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4)
static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NHWC;
#else
static const int kCudnnFlag = -1;
#endif
};
/*! \brief default layout for 5d tensor */
const int default_layout_5d = kNCDHW;
/*! \brief namespace for operators */
namespace op {
// binary operator
/*! \brief mul operator */
struct mul{
/*! \brief map a, b to result using defined operation */
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return a * b;
}
};
/*! \brief plus operator */
struct plus {
/*! \brief map a, b to result using defined operation */
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return a + b;
}
};
/*! \brief minus operator */
struct minus {
/*! \brief map a, b to result using defined operation */
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return a - b;
}
};
/*! \brief divide operator */
struct div {
/*! \brief map a, b to result using defined operation */
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return a / b;
}
};
/*! \brief get rhs */
struct right {
/*! \brief map a, b to result using defined operation */
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return b;
}
};
// unary operator / function: example
// these operators can be defined by the user,
// in the same style as the binary and unary operators above;
// to use, simply write F<op::identity>( src )
/*! \brief identity function that maps a real number to itself */
struct identity{
/*! \brief map a to result using defined operation */
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return a;
}
};
} // namespace op
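/*!
 * \brief Example: a user-defined unary operator in the same style
 *        (the name `square` is hypothetical, not part of mshadow):
 * \code
 *   struct square {
 *     template<typename DType>
 *     MSHADOW_XINLINE static DType Map(DType a) {
 *       return a * a;  // element-wise square
 *     }
 *   };
 *   // later, with the expression templates: dst = F<square>(src);
 * \endcode
 */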
/*! \brief namespace for savers */
namespace sv {
/*! \brief save to saver: = */
struct saveto {
/*! \brief save b to a using save method */
template<typename DType>
MSHADOW_XINLINE static void Save(DType &a, DType b) { // NOLINT(*)
a = b;
}
/*! \brief helper constant to use BLAS, alpha */
inline static default_real_t AlphaBLAS(void) { return 1.0f; }
/*! \brief helper constant to use BLAS, beta */
inline static default_real_t BetaBLAS(void) { return 0.0f; }
/*! \brief corresponding binary operator type */
typedef op::right OPType;
};
/*! \brief save to saver: += */
struct plusto {
/*! \brief save b to a using save method */
template<typename DType>
MSHADOW_XINLINE static void Save(DType &a, DType b) { // NOLINT(*)
a += b;
}
/*! \brief helper constant to use BLAS, alpha */
inline static default_real_t AlphaBLAS(void) { return 1.0f; }
/*! \brief helper constant to use BLAS, beta */
inline static default_real_t BetaBLAS(void) { return 1.0f; }
/*! \brief corresponding binary operator type */
typedef op::plus OPType;
};
/*! \brief minus to saver: -= */
struct minusto {
/*! \brief save b to a using save method */
template<typename DType>
MSHADOW_XINLINE static void Save(DType &a, DType b) { // NOLINT(*)
a -= b;
}
/*! \brief helper constant to use BLAS, alpha */
inline static default_real_t AlphaBLAS(void) { return -1.0f; }
/*! \brief helper constant to use BLAS, beta */
inline static default_real_t BetaBLAS(void) { return 1.0f; }
/*! \brief corresponding binary operator type */
typedef op::minus OPType;
};
/*! \brief multiply to saver: *= */
struct multo {
/*! \brief save b to a using save method */
template<typename DType>
MSHADOW_XINLINE static void Save(DType &a, DType b) { // NOLINT(*)
a *= b;
}
/*! \brief corresponding binary operator type */
typedef op::mul OPType;
};
/*! \brief divide to saver: /= */
struct divto {
/*! \brief save b to a using save method */
template<typename DType>
MSHADOW_XINLINE static void Save(DType& a, DType b) { // NOLINT(*)
a /= b;
}
/*! \brief corresponding binary operator type */
typedef op::div OPType;
};
} // namespace sv
/*! \brief namespace for potential reducer operations */
namespace red {
namespace limits {
/*!
* \brief minimum value of certain types
* \tparam DType data type
*/
template<typename DType>
MSHADOW_XINLINE DType MinValue(void);
/*! \brief minimum value of float */
template<>
MSHADOW_XINLINE float MinValue<float>(void) {
return -FLT_MAX;
}
/*! \brief minimum value of double */
template<>
MSHADOW_XINLINE double MinValue<double>(void) {
return -DBL_MAX;
}
/*! \brief minimum value of half */
template<>
MSHADOW_XINLINE half::half_t MinValue<half::half_t>(void) {
return MSHADOW_HALF_MIN;
}
/*! \brief minimum value of int */
template<>
MSHADOW_XINLINE int MinValue<int>(void) {
return INT_MIN;
}
/*! \brief minimum value of uint8_t */
template<>
MSHADOW_XINLINE uint8_t MinValue<uint8_t>(void) {
return 0;
}
} // namespace limits
/*! \brief sum reducer */
struct sum {
/*! \brief do reduction into dst */
template<typename DType>
MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*)
dst += src;
}
/*!
*\brief calculate gradient of redres with respect to redsrc,
* redres: reduced result, redsrc: one of reduction element
*/
template<typename DType>
MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) {
return 1;
}
/*!
*\brief set the initial value during reduction
*/
template<typename DType>
MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*)
initv = 0;
}
};
/*! \brief maximum reducer */
struct maximum {
/*! \brief do reduction into dst */
template<typename DType>
MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*)
using namespace std;
#ifdef __CUDACC__
dst = ::max(dst, src);
#else
dst = max(dst, src);
#endif // __CUDACC__
}
/*!
* \brief calculate gradient of redres with respect to redsrc,
* redres: reduced result, redsrc: one of reduction element
*/
template<typename DType>
MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) {
return redres == redsrc ? 1: 0;
}
/*!
*\brief set the initial value during reduction
*/
template<typename DType>
MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*)
initv = limits::MinValue<DType>();
}
};
/*! \brief minimum reducer */
struct minimum {
/*! \brief do reduction into dst */
template<typename DType>
MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*)
using namespace std;
#ifdef __CUDACC__
dst = ::min(dst, src);
#else
dst = min(dst, src);
#endif // __CUDACC__
}
/*!
* \brief calculate gradient of redres with respect to redsrc,
* redres: reduced result, redsrc: one of reduction element
*/
template<typename DType>
MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) {
return redres == redsrc ? 1: 0;
}
/*!
*\brief set the initial value during reduction
*/
template<typename DType>
MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*)
initv = -limits::MinValue<DType>();
}
};
} // namespace red
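/*!
 * \brief Example: how a reducer is typically driven (a minimal sketch;
 *        the explicit loop is illustrative, not mshadow's actual kernel):
 * \code
 *   float data[4] = {1.f, 5.f, 2.f, 3.f};
 *   float res;
 *   red::maximum::SetInitValue(res);       // res = -FLT_MAX
 *   for (int i = 0; i < 4; ++i) {
 *     red::maximum::Reduce(res, data[i]);  // res = max(res, data[i])
 *   }
 *   // res == 5.f
 * \endcode
 */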
#define MSHADOW_TYPE_SWITCH(type, DType, ...) \
switch (type) { \
case mshadow::kFloat32: \
{ \
typedef float DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kFloat64: \
{ \
typedef double DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kFloat16: \
{ \
typedef mshadow::half::half_t DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kUint8: \
{ \
typedef uint8_t DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kInt32: \
{ \
typedef int32_t DType; \
{__VA_ARGS__} \
} \
break; \
default: \
LOG(FATAL) << "Unknown type enum " << type; \
}
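/*!
 * \brief Example: dispatching on a runtime type flag with the switch above
 *        (a minimal sketch; `flag` and `nbytes` are hypothetical names):
 * \code
 *   int flag = mshadow::kFloat64;
 *   size_t nbytes = 0;
 *   MSHADOW_TYPE_SWITCH(flag, DType, {
 *     nbytes = sizeof(DType);   // DType is double in this branch
 *   });
 *   // see also mshadow_sizeof() below, which uses the same pattern
 * \endcode
 */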
#define MSHADOW_REAL_TYPE_SWITCH(type, DType, ...) \
switch (type) { \
case mshadow::kFloat32: \
{ \
typedef float DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kFloat64: \
{ \
typedef double DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kFloat16: \
{ \
typedef mshadow::half::half_t DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kUint8: \
LOG(FATAL) << "This operation only support " \
"floating point types not uint8"; \
break; \
case mshadow::kInt32: \
LOG(FATAL) << "This operation only support " \
"floating point types, not int32"; \
break; \
default: \
LOG(FATAL) << "Unknown type enum " << type; \
}
#define MSHADOW_LAYOUT_SWITCH(layout, Layout, ...) \
switch (layout) { \
case mshadow::kNCHW: \
{ \
const int Layout = kNCHW; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kNHWC: \
{ \
const int Layout = kNHWC; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kNCDHW: \
{ \
const int Layout = kNCDHW; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kNDHWC: \
{ \
const int Layout = kNDHWC; \
{__VA_ARGS__} \
} \
break; \
default: \
LOG(FATAL) << "Unknown layout enum " << layout; \
}
/*! \brief get data type size from type enum */
inline size_t mshadow_sizeof(int type) {
int size = 0;
MSHADOW_TYPE_SWITCH(type, DType, size = sizeof(DType););
return size;
}
} // namespace mshadow
#endif // MSHADOW_BASE_H_
//===== EXPANDED: ../mshadow/mshadow/base.h =====
//===== EXPANDING: ../mshadow/mshadow/expression.h =====
/*!
* Copyright (c) 2014 by Contributors
* \file expression.h
* \brief definitions of abstract expressions and expressions template
* \author Tianqi Chen, Bing Xu
*/
#ifndef MSHADOW_EXPRESSION_H_
#define MSHADOW_EXPRESSION_H_
namespace mshadow {
/*!
 * \brief namespace for abstract expressions and expression templates;
 *        it has no dependency on tensor.h.
 *        These data structures take no part in computation themselves;
 *        they are only used to define operations and represent expressions symbolically
*/
namespace expr {
/*! \brief type of expressions */
namespace type {
// expression types are defined as a bitmask
// subtype relationship: kRValue < kMapper < kChainer < kComplex
/*!
 * \brief this expression directly corresponds to a data class,
* can be used to assign data
*/
const int kRValue = 0;
/*!
* \brief expression contains element-wise tensor operations,
 *        maps an expression to one of the same shape
*/
const int kMapper = 1;
/*!
 * \brief expression that can be chained with other expressions
 *        Usually it has a function Eval(i,j) defined, which pulls the result (i, j) from the input
 *        expression and outputs the result at a certain position.
*/
const int kChainer = 3;
/*! \brief other cases: e.g. dot product */
const int kComplex = 7;
} // namespace type
/*!
* \brief expression engine that actually interprets these expressions
 * this is a function template that needs to be implemented for specific expressions
* \tparam Saver the save method
* \tparam RValue the type of RValue to be saved
* \sa namespace sv
*/
template<typename Saver, typename RValue, typename DType>
struct ExpEngine;
/*! \brief defines how expression exp can be evaluated and stored into dst */
// template<typename EType>
// inline static void Eval(RValue *dst, const EType &exp);
/*!
* \brief base class for expression
 * \tparam SubType the inheriting class must put its own type into this parameter
* \tparam DType the data type of each element in the expression
* \tparam exp_type expression type, see namespace type
*/
template<typename SubType, typename DType, int exp_type>
struct Exp {
public:
/*! \return subtype instance of current class */
inline const SubType& self(void) const {
return *static_cast<const SubType*>(this);
}
/*! \return reference of subtype instance of current class */
inline SubType* ptrself(void) {
return static_cast<SubType*>(this);
}
};
/*!
* \brief scalar expression
* \tparam DType the data type of the scalar
*/
template<typename DType>
struct ScalarExp: public Exp<ScalarExp<DType>, DType, type::kMapper> {
/*! \brief scalar value */
DType scalar_;
/*! \brief implicit constructor, MUST NOT BE explicit */
ScalarExp(DType scalar) : scalar_(scalar) {} // NOLINT(*)
};
/*! \brief create a scalar expression */
template<typename DType>
inline ScalarExp<DType> scalar(DType s) {
return ScalarExp<DType>(s);
}
/*!
* \brief typecast expression, cast the type of elements
* \tparam DstDType the target type we want to cast into
* \tparam SrcDType the target type we want to cast from
* \tparam EType the type of the source expression
* \tparam etype the type of expression after cast
*/
template<typename DstDType, typename SrcDType, typename EType, int etype>
struct TypecastExp:
public Exp<TypecastExp<DstDType, SrcDType, EType, etype>,
DstDType, etype> {
/*! \brief expression to be typecasted */
const EType &exp;
/*! \brief constructor */
explicit TypecastExp(const EType &e) : exp(e) {}
};
/*! \brief create a typecast expression */
template<typename DstDType, typename SrcDType,
typename EType, int etype>
inline TypecastExp<DstDType, SrcDType, EType, (etype|type::kMapper)>
tcast(const Exp<EType, SrcDType, etype> &exp) {
return TypecastExp<DstDType, SrcDType, EType, (etype|type::kMapper)>(exp.self());
}
/*! \brief represent a transpose expression of a container */
template<typename EType, typename DType>
struct TransposeExp: public Exp<TransposeExp<EType, DType>,
DType, type::kChainer> {
/*! \brief expression to be transposed */
const EType &exp;
/*! \brief constructor */
explicit TransposeExp(const EType &e) : exp(e) {}
/*! \brief transpose expression */
inline const EType &T(void) const {
return exp;
}
};
/*!
* \brief base class of all rvalues
 * \tparam Container the actual class of the data container, e.g. Tensor1D
* \tparam DataType the element data type of each element in the container
*/
template<typename Container, typename DType>
class RValueExp: public Exp<Container, DType, type::kRValue> {
public:
/*!
*\brief transpose of a matrix
*\return transpose of current expression
*/
inline const TransposeExp<Container, DType> T(void) const {
return TransposeExp<Container, DType>(this->self());
}
/*! \brief operator overload */
inline Container &operator+=(DType s) {
ExpEngine<sv::plusto, Container, DType>::Eval(this->ptrself(), scalar<DType>(s));
return *(this->ptrself());
}
/*! \brief operator overload */
inline Container &operator-=(DType s) {
ExpEngine<sv::minusto, Container, DType>::Eval(this->ptrself(), scalar<DType>(s));
return *(this->ptrself());
}
/*! \brief operator overload */
inline Container &operator*=(DType s) {
ExpEngine<sv::multo, Container, DType>::Eval(this->ptrself(), scalar<DType>(s));
return *(this->ptrself());
}
/*! \brief operator overload */
inline Container &operator/=(DType s) {
ExpEngine<sv::divto, Container, DType>::Eval(this->ptrself(), scalar<DType>(s));
return *(this->ptrself());
}
/*! \brief operator overload */
inline Container &__assign(DType s) {
ExpEngine<sv::saveto, Container, DType>::Eval(this->ptrself(), scalar<DType>(s));
return *(this->ptrself());
}
/*! \brief we can not define container = container */
template<typename E, int etype>
inline Container &__assign(const Exp<E, DType, etype> &exp) {
ExpEngine<sv::saveto, Container, DType>::Eval(this->ptrself(), exp.self());
return *(this->ptrself());
}
/*! \brief operator overload, assign */
inline Container &__assign(const Exp<Container, DType, type::kRValue> &exp);
/*! \brief implementation of operator+= */
template<typename E, int etype>
inline Container &operator+=(const Exp<E, DType, etype> &exp) {
ExpEngine<sv::plusto, Container, DType>::Eval(this->ptrself(), exp.self());
return *(this->ptrself());
}
/*! \brief implementation of operator-= */
template<typename E, int etype>
inline Container &operator-=(const Exp<E, DType, etype> &exp) {
ExpEngine<sv::minusto, Container, DType>::Eval(this->ptrself(), exp.self());
return *(this->ptrself());
}
/*! \brief implementation of operator*= */
template<typename E, int etype>
inline Container &operator*=(const Exp<E, DType, etype> &exp) {
ExpEngine<sv::multo, Container, DType>::Eval(this->ptrself(), exp.self());
return *(this->ptrself());
}
/*! \brief implementation of operator/= */
template<typename E, int etype>
inline Container &operator/=(const Exp<E, DType, etype> &exp) {
ExpEngine<sv::divto, Container, DType>::Eval(this->ptrself(), exp.self());
return *(this->ptrself());
}
};
/*!
* \brief matrix multiplication expression dot(lhs[.T], rhs[.T])
* \tparam TA type of lhs
* \tparam TB type of rhs
* \tparam ltrans whether lhs is transposed
* \tparam rtrans whether rhs is transposed
* \tparam DType the data type of the scalar
*/
template<typename TA, typename TB, bool ltrans, bool rtrans, typename DType>
struct DotExp: public Exp<DotExp<TA, TB, ltrans, rtrans, DType>,
DType, type::kComplex> {
/*! \brief left operand */
const TA &lhs_;
/*! \brief right operand */
const TB &rhs_;
/*! \brief scale over result */
DType scale_;
/*! \brief constructor */
explicit DotExp(const TA &lhs, const TB &rhs, DType scale)
: lhs_(lhs), rhs_(rhs), scale_(scale) {}
};
// definition of dot expression
/*! \brief dot operator def */
template<typename TA, typename TB, typename DType>
inline DotExp<TA, TB, false, false, DType>
dot(const RValueExp<TA, DType> &lhs, const RValueExp<TB, DType> &rhs) {
return DotExp<TA, TB, false, false, DType>(lhs.self(), rhs.self(), DType(1.0f));
}
/*! \brief dot operator def */
template<typename TA, typename TB, typename DType>
inline DotExp<TA, TB, true, false, DType>
dot(const TransposeExp<TA, DType> &lhs, const RValueExp<TB, DType> &rhs) {
return DotExp<TA, TB, true, false, DType>(lhs.exp, rhs.self(), DType(1.0f));
}
/*! \brief dot operator def */
template<typename TA, typename TB, typename DType>
inline DotExp<TA, TB, false, true, DType>
dot(const RValueExp<TA, DType> &lhs, const TransposeExp<TB, DType> &rhs) {
return DotExp<TA, TB, false, true, DType>(lhs.self(), rhs.exp, DType(1.0f));
}
/*! \brief dot operator def */
template<typename TA, typename TB, typename DType>
inline DotExp<TA, TB, true, true, DType>
dot(const TransposeExp<TA, DType> &lhs, const TransposeExp<TB, DType> &rhs) {
return DotExp<TA, TB, true, true, DType>(lhs.exp, rhs.exp, DType(1.0f));
}
/*! \brief batch_dot operator def */
template<bool transpose_left, bool transpose_right, typename TA, typename TB, typename DType>
inline DotExp<TA, TB, transpose_left, transpose_right, DType>
batch_dot(const RValueExp<TA, DType> &lhs, const RValueExp<TB, DType> &rhs) {
return DotExp<TA, TB, transpose_left, transpose_right, DType>(
lhs.self(), rhs.self(), DType(1.0f));
}
//---------------
// TernaryMapExp
// --------------
/*!
* \brief ternary map expression
* \tparam OP operator
* \tparam TA type of item1
* \tparam TB type of item2
* \tparam etype expression type, sa namespace::type
*/
template<typename OP, typename TA, typename TB, typename TC, typename DType, int etype>
struct TernaryMapExp: public Exp<TernaryMapExp<OP, TA, TB, TC, DType, etype>,
DType, etype> {
/*! \brief first operand */
const TA &item1_;
/*! \brief second operand */
const TB &item2_;
/*! \brief third operand */
const TC &item3_;
/*! \brief constructor */
explicit TernaryMapExp(const TA &item1, const TB &item2, const TC &item3)
:item1_(item1), item2_(item2), item3_(item3) {}
};
/*! \brief make expression */
template<typename OP, typename TA, typename TB, typename TC, typename DType, int ta, int tb, int tc>
inline TernaryMapExp<OP, TA, TB, TC, DType, (ta|tb|tc|type::kMapper)>
MakeExp(const Exp<TA, DType, ta> &item1, const Exp<TB, DType, tb> &item2,
const Exp<TC, DType, tc> &item3) {
return TernaryMapExp<OP, TA, TB, TC, DType,
(ta|tb|tc|type::kMapper)>(item1.self(), item2.self(), item3.self());
}
/*!
* \brief short hand for MakeExp, usage F<op>(item1,item2,item3). create a ternary operation expression
* \param item1 first operand
* \param item2 second operand
* \param item3 third operand
* \return the result expression
 * \tparam OP the ternary operator
* \tparam TA item1 expression
* \tparam ta item1 expression type
* \tparam TB item2 expression
* \tparam tb item2 expression type
* \tparam TC item3 expression
* \tparam tc item3 expression type
* \sa mshadow::op
*/
// Ternary
template<typename OP, typename TA, typename TB, typename TC, typename DType, int ta, int tb, int tc>
inline TernaryMapExp<OP, TA, TB, TC, DType, (ta|tb|tc|type::kMapper)>
F(const Exp<TA, DType, ta> &item1, const Exp<TB, DType, tb> &item2,
const Exp<TC, DType, tc> &item3) {
return MakeExp<OP>(item1, item2, item3);
}
//---------------
// BinaryMapExp
// --------------
/*!
* \brief binary map expression lhs [op] rhs
* \tparam OP operator
* \tparam TA type of lhs
* \tparam TB type of rhs
* \tparam etype expression type, sa namespace::type
*/
template<typename OP, typename TA, typename TB, typename DType, int etype>
struct BinaryMapExp: public Exp<BinaryMapExp<OP, TA, TB, DType, etype>,
DType, etype> {
/*! \brief left operand */
const TA &lhs_;
/*! \brief right operand */
const TB &rhs_;
/*! \brief constructor */
explicit BinaryMapExp(const TA &lhs, const TB &rhs)
:lhs_(lhs), rhs_(rhs) {}
};
/*! \brief make expression */
template<typename OP, typename TA, typename TB, typename DType, int ta, int tb>
inline BinaryMapExp<OP, TA, TB, DType, (ta|tb|type::kMapper)>
MakeExp(const Exp<TA, DType, ta> &lhs, const Exp<TB, DType, tb> &rhs) {
return BinaryMapExp<OP, TA, TB, DType,
(ta|tb|type::kMapper)>(lhs.self(), rhs.self());
}
/*!
* \brief short hand for MakeExp, usage F<op>(lhs, rhs). create a binary operation expression
* \param lhs left operand
* \param rhs right operand
* \return the result expression
 * \tparam OP the binary operator
* \tparam TA lhs expression
* \tparam ta lhs expression type
* \tparam TB rhs expression
* \tparam tb rhs expression type
* \sa mshadow::op
*/
template<typename OP, typename TA, typename TB, typename DType, int ta, int tb>
inline BinaryMapExp<OP, TA, TB, DType, (ta|tb|type::kMapper)>
F(const Exp<TA, DType, ta> &lhs, const Exp<TB, DType, tb> &rhs) {
return MakeExp<OP>(lhs, rhs);
}
// operator rules
/*! \brief operator overload */
template<typename TA, typename TB, typename DType, int ta, int tb>
inline BinaryMapExp<op::plus, TA, TB, DType, (ta|tb|type::kMapper)>
operator+(const Exp<TA, DType, ta> &lhs, const Exp<TB, DType, tb> &rhs) {
return MakeExp<op::plus>(lhs, rhs);
}
/*! \brief operator overload */
template<typename TA, typename TB, typename DType, int ta, int tb>
inline BinaryMapExp<op::minus, TA, TB, DType, (ta|tb|type::kMapper)>
operator-(const Exp<TA, DType, ta> &lhs, const Exp<TB, DType, tb> &rhs) {
return MakeExp<op::minus>(lhs, rhs);
}
/*! \brief operator overload */
template<typename TA, typename TB, typename DType, int ta, int tb>
inline BinaryMapExp<op::mul, TA, TB, DType, (ta|tb|type::kMapper)>
operator*(const Exp<TA, DType, ta> &lhs, const Exp<TB, DType, tb> &rhs) {
return MakeExp<op::mul>(lhs, rhs);
}
/*! \brief operator overload */
template<typename TA, typename TB, typename DType, int ta, int tb>
inline BinaryMapExp<op::div, TA, TB, DType, (ta|tb|type::kMapper)>
operator/(const Exp<TA, DType, ta> &lhs, const Exp<TB, DType, tb> &rhs) {
return MakeExp<op::div>(lhs, rhs);
}
//---------------
// UnaryMapExp
// --------------
/*!
* \brief unary map expression op(src)
* \tparam OP operator
* \tparam TA type of src
* \tparam etype expression type, sa namespace::type
*/
template<typename OP, typename TA, typename DType, int etype>
struct UnaryMapExp: public Exp<UnaryMapExp<OP, TA, DType, etype>,
DType, etype> {
/*! \brief source expression */
const TA &src_;
/*! \brief constructor */
explicit UnaryMapExp(const TA &src) : src_(src) {}
};
/*! \brief make expression */
template<typename OP, typename TA, typename DType, int ta>
inline UnaryMapExp<OP, TA, DType, (ta|type::kMapper)>
MakeExp(const Exp<TA, DType, ta> &src) {
return UnaryMapExp<OP, TA, DType, (ta|type::kMapper)>(src.self());
}
/*!
* \brief short hand for MakeExp, usage F<op>(src), create a unary operation expression
* \param src source expression
* \return the result expression
 * \tparam OP the unary operator
* \tparam TA source expression
* \tparam ta source expression type
* \sa mshadow::op
*/
template<typename OP, typename TA, typename DType, int ta>
inline UnaryMapExp<OP, TA, DType, (ta|type::kMapper)>
F(const Exp<TA, DType, ta> &src) {
return MakeExp<OP>(src);
}
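/*!
 * \brief Example: composing expressions (a minimal sketch; `dst`, `lhs` and
 *        `rhs` are hypothetical rvalue containers of matching shape, e.g. the
 *        Tensor type declared later in tensor.h):
 * \code
 *   dst = F<op::identity>(lhs) * rhs + scalar(0.5f);  // element-wise, lazy until assignment
 *   dst += dot(lhs.T(), rhs);                         // matrix product via the dot expression
 * \endcode
 */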
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXPRESSION_H_
//===== EXPANDED: ../mshadow/mshadow/expression.h =====
namespace mshadow {
/*! \brief device name CPU */
struct cpu {
/*! \brief whether this device is CPU or not */
static const bool kDevCPU = true;
/*! \brief device flag number, identifies this device */
static const int kDevMask = 1 << 0;
};
/*! \brief device name GPU */
struct gpu {
/*! \brief whether this device is CPU or not */
static const bool kDevCPU = false;
/*! \brief device flag number, identifies this device */
static const int kDevMask = 1 << 1;
};
template<int ndim>
struct Shape;
/*!
* \brief allow string printing of the shape
* \param os the output stream
* \param shape the shape
* \return the ostream
*/
template<int ndim>
inline std::ostream &operator<<(std::ostream &os, const Shape<ndim> &shape); // NOLINT(*)
/*!
* \brief shape of a tensor
* \tparam dimension dimension of tensor
*/
template<int dimension>
struct Shape {
/*! \brief dimension of current shape */
static const int kDimension = dimension;
/*! \brief dimension of current shape minus one */
static const int kSubdim = dimension - 1;
/*! \brief storing the dimension information */
index_t shape_[kDimension];
/*! \brief default constructor, do nothing */
MSHADOW_XINLINE Shape(void) {}
/*! \brief constructor */
MSHADOW_XINLINE Shape(const Shape<kDimension> &s) {
#pragma unroll
for (int i = 0; i < kDimension; ++i) {
this->shape_[i] = s[i];
}
}
/*!
* \brief get corresponding index
* \param idx dimension index
* \return the corresponding dimension size
*/
MSHADOW_XINLINE index_t &operator[](index_t idx) {
return shape_[idx];
}
/*!
* \brief get corresponding index
* \param idx dimension index
* \return the corresponding dimension size
*/
MSHADOW_XINLINE const index_t &operator[](index_t idx) const {
return shape_[idx];
}
/*!
* \return whether two shape equals
* \param s the shape to compare against
*/
MSHADOW_XINLINE bool operator==(const Shape<kDimension> &s) const {
#pragma unroll
for (int i = 0; i < kDimension; ++i) {
if (s.shape_[i] != this->shape_[i]) return false;
}
return true;
}
/*!
* \return whether two shape not equal
* \param s the shape to compare against
*/
MSHADOW_XINLINE bool operator!=(const Shape<kDimension> &s) const {
return !(*this == s);
}
/*!
* flatten the tensor, return a 1D shape
* \return the flat 1d shape
*/
MSHADOW_XINLINE Shape<1> FlatTo1D(void) const {
Shape<1> s;
s[0] = this->Size();
return s;
}
/*!
* flatten the higher dimension to second dimension, return a 2D shape
* \return the flat 2d shape
*/
MSHADOW_XINLINE Shape<2> FlatTo2D(void) const {
Shape<2> s;
s.shape_[1] = this->shape_[kDimension - 1];
index_t ymax = 1;
#pragma unroll
for (int i = 0; i < kDimension - 1; ++i) {
ymax *= this->shape_[i];
}
s.shape_[0] = ymax;
return s;
}
/*! \return number of valid elements */
MSHADOW_XINLINE size_t Size(void) const {
size_t size = this->shape_[0];
#pragma unroll
for (int i = 1; i < kDimension; ++i) {
size *= this->shape_[i];
}
return size;
}
/*!
* \return product shape in [dimstart,dimend)
* \param dimstart start dimension
* \param dimend end dimension
*/
MSHADOW_XINLINE index_t ProdShape(int dimstart, int dimend) const {
index_t num = 1;
#pragma unroll
for (int i = dimstart; i < dimend; ++i) {
num *= this->shape_[i];
}
return num;
}
/*!
* \brief get subshape that takes off largest dimension
 * \return subshape
*/
MSHADOW_XINLINE Shape<kSubdim> SubShape(void) const {
Shape<kSubdim> s;
// for cuda
#pragma unroll
for (int i = 0; i < kSubdim; ++i) {
s.shape_[i] = this->shape_[i + 1];
}
return s;
}
/*!
* \brief slice the shape from start to end
* \tparam dimstart start dimension
* \tparam dimend end dimension
* \return the sliced shape
*/
template<int dimstart, int dimend>
MSHADOW_XINLINE Shape<dimend - dimstart> Slice(void) const {
Shape<dimend - dimstart> s;
#pragma unroll
for (int i = dimstart; i < dimend; ++i) {
s[i - dimstart] = this->shape_[i];
}
return s;
}
//! \cond Doxygen_Suppress
template<int dim>
friend std::ostream &operator<<(std::ostream &os, const Shape<dim> &shape); // NOLINT(*)
//! \endcond
}; // Shape
//------------------------------------------------
// useful construction functions to generate shape
//-------------------------------------------------
/*!
* \brief construct a one dimension shape, stride will equal s0
* \param s0 size of dimension 0
* \return the shape construction
*/
MSHADOW_XINLINE Shape<1> Shape1(index_t s0) {
Shape<1> s; s[0] = s0;
return s;
}
/*!
* \brief construct a two dimension shape, stride will equal s0
* \param s0 size of dimension 0
* \param s1 size of dimension 1
* \return the shape construction
*/
MSHADOW_XINLINE Shape<2> Shape2(index_t s0, index_t s1) {
Shape<2> s; s[0] = s0; s[1] = s1;
return s;
}
/*!
* \brief construct a three dimension shape, stride will equal s0
* \param s0 size of dimension 0
* \param s1 size of dimension 1
* \param s2 size of dimension 2
* \return the shape construction
*/
MSHADOW_XINLINE Shape<3> Shape3(index_t s0, index_t s1, index_t s2) {
Shape<3> s;
s[0] = s0; s[1] = s1; s[2] = s2;
return s;
}
/*!
* \brief construct a four dimension shape, stride will equal s0
* \param s0 size of dimension 0
* \param s1 size of dimension 1
* \param s2 size of dimension 2
* \param s3 size of dimension 3
* \return the shape construction
*/
MSHADOW_XINLINE Shape<4> Shape4(index_t s0, index_t s1,
index_t s2, index_t s3) {
Shape<4> s;
s[0] = s0; s[1] = s1; s[2] = s2; s[3] = s3;
return s;
}
/*!
* \brief construct a five dimension shape, stride will equal s0
* \param s0 size of dimension 0
* \param s1 size of dimension 1
* \param s2 size of dimension 2
* \param s3 size of dimension 3
* \param s4 size of dimension 4
* \return the shape construction
*/
MSHADOW_XINLINE Shape<5> Shape5(index_t s0, index_t s1, index_t s2,
index_t s3, index_t s4) {
Shape<5> s;
s[0] = s0; s[1] = s1; s[2] = s2; s[3] = s3; s[4] = s4;
return s;
}
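/*!
 * \brief Example: constructing and inspecting shapes (a minimal sketch):
 * \code
 *   Shape<3> s = Shape3(2, 3, 4);
 *   size_t n = s.Size();             // 24 elements
 *   Shape<2> flat = s.FlatTo2D();    // (6, 4)
 *   index_t hw = s.ProdShape(1, 3);  // 3 * 4 = 12
 * \endcode
 */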
/*!
* \brief Convert shape in src_layout to shape in dst_layout
* \param src original shape
* \param src_layout layout of original shape
* \param dst_layout target layout
* \return shape in target layout
*/
inline Shape<4> ConvertLayout(const Shape<4>& src, int src_layout, int dst_layout) {
Shape<4> dst;
switch (src_layout) {
case kNCHW:
dst = src;
break;
case kNHWC:
dst[0] = src[0];
dst[2] = src[1];
dst[3] = src[2];
dst[1] = src[3];
break;
default:
LOG(FATAL) << "Invalid layout for 4d shape " << src_layout;
}
Shape<4> dst2;
switch (dst_layout) {
case kNCHW:
return dst;
case kNHWC:
dst2[0] = dst[0];
dst2[1] = dst[2];
dst2[2] = dst[3];
dst2[3] = dst[1];
break;
default:
LOG(FATAL) << "Invalid layout for 4d shape " << src_layout;
}
return dst2;
}
/*!
* \brief Convert shape in src_layout to shape in dst_layout
* \param src original shape
* \param src_layout layout of original shape
* \param dst_layout target layout
* \return shape in target layout
*/
inline Shape<5> ConvertLayout(const Shape<5>& src, int src_layout, int dst_layout) {
Shape<5> dst;
switch (src_layout) {
case kNCDHW:
dst = src;
break;
case kNDHWC:
dst[0] = src[0];
dst[2] = src[1];
dst[3] = src[2];
dst[4] = src[3];
dst[1] = src[4];
break;
default:
LOG(FATAL) << "Invalid layout for 5d shape " << src_layout;
}
Shape<5> dst2;
switch (dst_layout) {
case kNCDHW:
return dst;
case kNDHWC:
dst2[0] = dst[0];
dst2[1] = dst[2];
dst2[2] = dst[3];
dst2[3] = dst[4];
dst2[4] = dst[1];
break;
default:
LOG(FATAL) << "Invalid layout for 5d shape " << src_layout;
}
return dst2;
}
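/*!
 * \brief Example: converting a 4d shape between layouts (a minimal sketch):
 * \code
 *   Shape<4> nchw = Shape4(8, 3, 32, 32);               // N, C, H, W
 *   Shape<4> nhwc = ConvertLayout(nchw, kNCHW, kNHWC);  // (8, 32, 32, 3)
 * \endcode
 */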
/*!
 * \brief computation stream structure, used for asynchronous computation
*/
template<typename Device>
struct Stream {
// this is only a dummy implementation for CPU
// for GPU, the actual implementation will be specialized in tensor_gpu-inl.h
/*!
* \brief wait for all the computation associated
* with this stream to complete
*/
inline void Wait(void) {}
/*!
 * \brief query whether the stream is idle
 * \return true if the stream is idle and all the jobs have been completed
*/
inline bool CheckIdle(void) {
return true;
}
/*! \brief create a blas handle */
inline void CreateBlasHandle() {}
};
/*!
* \brief Tensor RValue, this is the super type of all kinds of possible tensors
* \tparam Container the tensor type
* \tparam Device which device the tensor is on
* \tparam dimension dimension of the tensor
* \tparam DType the type of elements in the tensor
*/
template<typename Container, typename Device, int dimension, typename DType>
struct TRValue: public expr::RValueExp<Container, DType> {
};
// more compact template
/*!
* \brief general tensor
* \tparam Device which device the tensor is on
* \tparam dimension dimension of the tensor
* \tparam DType the type of elements in the tensor
*/
template<typename Device, int dimension,
typename DType MSHADOW_DEFAULT_DTYPE>
struct Tensor: public TRValue<Tensor<Device, dimension, DType>,
Device, dimension, DType> {
public:
//--------------------------------
// struct members
//--------------------------------
/*! \brief whether current type lies in cpu */
static const bool kDevCPU = Device::kDevCPU;
/*! \brief dimension of subtype */
static const int kSubdim = dimension - 1;
//--------------------------------
// struct members
//--------------------------------
/*! \brief pointer to the data */
DType *dptr_;
/*! \brief shape of the tensor */
Shape<dimension> shape_;
/*!
* \brief storing the stride information in x dimension
 *   this is used to deal with pitch allocation on GPU or SSE (aligning the x dimension to 64 bits) for efficiency
*/
index_t stride_;
/*!
 * \brief stream where the computation lies
 *        a stream is a device-dependent concept on which each computation is queued
*/
Stream<Device> *stream_;
//--------------------------------
// functions
//--------------------------------
/*! \brief default constructor */
MSHADOW_XINLINE Tensor(void) : stream_(NULL) {}
/*! \brief constructor from shape */
MSHADOW_XINLINE Tensor(const Shape<dimension> &shape)
: shape_(shape), stream_(NULL) {}
/*! \brief constructor from data pointer and shape, without stride */
MSHADOW_XINLINE Tensor(DType *dptr, const Shape<dimension> &shape)
: dptr_(dptr), shape_(shape), stride_(shape[kSubdim]), stream_(NULL) {}
/*! \brief constructor from data pointer and shape, without stride */
MSHADOW_XINLINE Tensor(DType *dptr, const Shape<dimension> &shape,
Stream<Device> *stream)
: dptr_(dptr), shape_(shape), stride_(shape[kSubdim]), stream_(stream) {}
/*! \brief constructor from data pointer and shape */
MSHADOW_XINLINE Tensor(DType *dptr,
const Shape<dimension> &shape,
index_t stride, Stream<Device> *stream)
: dptr_(dptr), shape_(shape), stride_(stride), stream_(stream) {}
/*!
* \brief set the stream to do computation of current tensor
* \param stream the computation stream
*/
inline void set_stream(Stream<Device> *stream) {
this->stream_ = stream;
}
/*!
* \return memory cost of the tensor, including the aligned x dimension
* \tparam startdim the starting dimension
*/
template<int startdim>
MSHADOW_XINLINE size_t MemSize(void) const {
size_t memsz = this->stride_;
#pragma unroll
for (int i = startdim; i < kSubdim; ++i) {
memsz *= this->shape_[i];
}
return memsz;
}
/*!
* \return whether the tensor's memory is continuous
* x dimension same as stride
*/
MSHADOW_XINLINE bool CheckContiguous(void) const {
return this->shape_[dimension - 1] == stride_;
}
/*!
* \return memory cost of the tensor, including the aligned x dimension
*/
MSHADOW_XINLINE size_t MSize(void) const {
return this->MemSize<0>();
}
/*!
* \brief return size of i-th dimension, start counting from highest dimension
 * \param idx the dimension count from the highest dimension
* \return the size
*/
MSHADOW_XINLINE index_t size(index_t idx) const {
return shape_[idx];
}
/*!
* \brief flatten the tensor to 1 dimension
* \return tensor after flatten
*/
MSHADOW_XINLINE Tensor<Device, 1, DType> FlatTo1D(void) const {
return Tensor<Device, 1, DType>(dptr_, shape_.FlatTo1D(), stride_, stream_);
}
/*!
* \brief flatten the tensor to 2 dimension, collapse the higher dimensions together
* \return tensor after flatten
*/
MSHADOW_XINLINE Tensor<Device, 2, DType> FlatTo2D(void) const {
return Tensor<Device, 2, DType>(dptr_, shape_.FlatTo2D(), stride_, stream_);
}
/*!
 * \brief get a sub-tensor of dimension - 1
* \param idx index
* \return the result tensor
*/
MSHADOW_XINLINE Tensor<Device, kSubdim, DType> operator[](index_t idx) const {
return Tensor<Device, kSubdim, DType>(dptr_ + this->MemSize<1>() * idx,
shape_.SubShape(), stride_, stream_);
}
/*!
* \brief slice the tensor in highest dimension [begin,end)
* \param begin begin position of slice
* \param end end position of slice
* \return tensor after slice
*/
MSHADOW_XINLINE Tensor<Device, dimension, DType>
Slice(index_t begin, index_t end) const {
Shape<dimension> s = this->shape_;
s[0] = end - begin;
return Tensor<Device, dimension, DType>(dptr_ + this->MemSize<1>() * begin,
s, stride_, stream_);
}
/*!\brief implement the assignment of same type */
inline Tensor<Device, dimension, DType> &
operator=(const Tensor<Device, dimension, DType> &exp) {
dptr_ = exp.dptr_;
shape_ = exp.shape_;
stride_ = exp.stride_;
stream_ = exp.stream_;
return *this;
}
/*!\brief functions to fit expression template */
template<typename E, int etype>
inline Tensor<Device, dimension, DType> &
operator=(const expr::Exp<E, DType, etype> &exp) {
return this->__assign(exp);
}
/*!\brief functions to fit expression template */
inline Tensor<Device, dimension, DType> &operator=(const DType &exp) {
return this->__assign(exp);
}
};
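/*!
 * \brief Example: wrapping an existing buffer in a Tensor (a minimal sketch;
 *        the buffer `buf` is hypothetical and assumed contiguous):
 * \code
 *   float buf[6] = {0, 1, 2, 3, 4, 5};
 *   Tensor<cpu, 2, float> t(buf, Shape2(2, 3));  // stride_ == 3
 *   float v = t[1][2];                           // element (1, 2) == 5
 *   Tensor<cpu, 2, float> row = t.Slice(0, 1);   // first row, shape (1, 3)
 *   Tensor<cpu, 1, float> flat = t.FlatTo1D();   // shape (6,)
 * \endcode
 */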
/*
 * specialized class Tensor1D; this is due to a different implementation of operator[]
*/
template<typename Device, typename DType>
struct Tensor<Device, 1, DType>:
public TRValue<Tensor<Device, 1, DType>, Device, 1, DType> {
public:
DType *dptr_;
Shape<1> shape_;
index_t stride_;
Stream<Device> *stream_;
// constructor
MSHADOW_XINLINE Tensor(void) : stream_(NULL) {}
MSHADOW_XINLINE Tensor(const Shape<1> &shape)
: shape_(shape), stream_(NULL) {}
MSHADOW_XINLINE Tensor(DType *dptr, Shape<1> shape)
: dptr_(dptr), shape_(shape), stride_(shape[0]), stream_(NULL) {}
MSHADOW_XINLINE Tensor(DType *dptr, Shape<1> shape, Stream<Device> *stream)
: dptr_(dptr), shape_(shape), stride_(shape[0]), stream_(stream) {}
MSHADOW_XINLINE Tensor(DType *dptr, Shape<1> shape,
index_t stride, Stream<Device> *stream)
: dptr_(dptr), shape_(shape), stride_(stride), stream_(stream) {}
inline void set_stream(Stream<Device> *stream) {
this->stream_ = stream;
}
MSHADOW_XINLINE Tensor<Device, 1, DType> FlatTo1D(void) const {
return *this;
}
MSHADOW_XINLINE Tensor<Device, 2, DType> FlatTo2D(void) const {
return Tensor<Device, 2, DType>(dptr_, shape_.FlatTo2D(), stride_, stream_);
}
MSHADOW_XINLINE Tensor<Device, 1, DType> Slice(index_t begin, index_t end) const {
Shape<1> s;
s[0] = end - begin;
return Tensor<Device, 1, DType>(dptr_ + begin, s, s[0], stream_);
}
MSHADOW_XINLINE bool CheckContiguous(void) const {
return true;
}
MSHADOW_XINLINE size_t MSize(void) const {
return shape_[0];
}
MSHADOW_XINLINE index_t size(index_t i) const {
return shape_[0];
}
MSHADOW_XINLINE DType &operator[](index_t idx) {
return dptr_[idx];
}
MSHADOW_XINLINE const DType &operator[](index_t idx) const {
return dptr_[idx];
}
/*!\brief implement the assignment of same type */
inline Tensor<Device, 1, DType> &
operator=(const Tensor<Device, 1, DType> &exp) {
dptr_ = exp.dptr_;
shape_ = exp.shape_;
stride_ = exp.stride_;
stream_ = exp.stream_;
return *this;
}
template<typename E, int etype>
inline Tensor<Device, 1, DType> &
operator=(const expr::Exp<E, DType, etype> &exp) {
return this->__assign(exp);
}
inline Tensor<Device, 1, DType> &operator=(const DType &exp) {
return this->__assign(exp);
}
};
//------------------------
// Function Declarations
//-----------------------
/*!
 * \brief initialize tensor engine, used to call initialization functions of dependent libs
 *        this function should be called before all GPU tensor operations;
 *        when using tensors on CPU, this call is not needed
 * \param device_id GPU device id to be chosen
* \tparam Device the device type
*/
template<typename Device>
inline void InitTensorEngine(int device_id = 0);
/*!
* \brief Shutdown tensor engine on current device
 *        this function should be called after all GPU tensor operations;
 *        when using tensors on CPU, this call is not needed
* \tparam Device the device type
*/
template<typename Device>
inline void ShutdownTensorEngine(void);
/*!
* \brief set the device of current thread to work on
* \param devid the device id
* \tparam Device the device type
*/
template<typename Device>
inline void SetDevice(int devid);
/*!
* \brief create a new stream from system
* \param create_blas_handle whether create blas handle in stream
* \param create_dnn_handle whether create cudnn handle in stream
* \return a pointer to the created stream
* \tparam Device the device type
*/
template<typename Device>
inline Stream<Device> *NewStream(bool create_blas_handle,
bool create_dnn_handle);
/*! \brief default behavior: create cublas handle */
template<typename Device>
inline Stream<Device> *NewStream() {
return NewStream<Device>(true, false);
}
/*!
* \brief delete the computing stream
* \param stream the stream parameter to be deleted
*/
template<typename Device>
inline void DeleteStream(Stream<Device> *stream);
/*!
* \brief CPU/CPU: allocate space for CTensor, according to the shape in the obj
 *        this function is responsible for setting the stride_ in each obj.shape
* \param obj the tensor object, with shape specified
* \param pad whether padding dimension 0, to make last dimension aligned,
* padding may help improve efficiency of matrix multiplications
 *            if true, will allocate space with stride_ that may not equal shape[0]
* if false, will allocate continuous space
* \tparam dim specify the dim of tensor
* \tparam DType type of element in tensor
*/
template<int dim, typename DType>
inline void AllocSpace(Tensor<cpu, dim, DType> *obj,
bool pad = MSHADOW_ALLOC_PAD);
/*!
* \brief CPU/CPU: allocate space for CTensor, according to the shape in the obj
 *        this function is responsible for setting the stride_ in each obj.shape
* \param obj the tensor object, with shape specified
* \param pad whether padding dimension 0, to make last dimension aligned,
* padding may help improve efficiency of matrix multiplications
 *            if true, will allocate space with stride_ that may not equal shape[0]
* if false, will allocate continuous space
* \tparam dim specify the dim of tensor
* \tparam DType type of element in tensor
*/
template<int dim, typename DType>
inline void AllocSpace(Tensor<gpu, dim, DType> *obj,
bool pad = MSHADOW_ALLOC_PAD);
/*!
* \brief CPU/GPU: free the space of tensor, will set obj.dptr to NULL
* \param obj the tensor object
* \tparam dim specify the dim of tensor
* \tparam DType type of element in tensor
*/
template<int dim, typename DType>
inline void FreeSpace(Tensor<cpu, dim, DType> *obj);
/*!
* \brief CPU/GPU: free the space of tensor, will set obj.dptr to NULL
* \param obj the tensor object
* \tparam dim specify the dim of tensor
* \tparam DType type of element in tensor
*/
template<int dim, typename DType>
inline void FreeSpace(Tensor<gpu, dim, DType> *obj);
/*!
* \brief CPU/GPU: short cut to allocate and initialize a Tensor
* \param shape: shape of tensor
* \param initv: initialization value
* \param pad : padding option
* \param stream : stream of tensor
* \tparam Device device of tensor
* \tparam DType type of element in tensor
 * \tparam dim dimension of tensor
* \return a new allocated tensor
* \sa AllocSpace
*/
template<typename Device, typename DType, int dim>
inline Tensor<Device, dim, DType> NewTensor(const Shape<dim> &shape,
DType initv,
bool pad = MSHADOW_ALLOC_PAD,
Stream<Device> *stream = NULL);
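/*!
 * \brief Example: allocating and releasing a CPU tensor with the helpers
 *        declared above (a minimal sketch):
 * \code
 *   Tensor<cpu, 2, float> t = NewTensor<cpu>(Shape2(4, 8), 0.0f);  // filled with 0
 *   t[0][0] = 1.0f;
 *   FreeSpace(&t);  // dptr_ becomes NULL
 * \endcode
 */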
/*!
* \brief copy data from one tensor to another, with same shape
* \param dst target tensor
* \param src source tensor
 * \param stream the stream; when specified, the copy can exhibit asynchronous behavior
* \tparam dim specify the dim of tensor
* \tparam DType type of element in tensor
*/
template<int dim, typename DType>
inline void Copy(Tensor<cpu, dim, DType> dst,
const Tensor<cpu, dim, DType> &src,
Stream<cpu> *stream = NULL);
/*!
* \brief copy data from one tensor to another, with same shape
* \param dst target tensor
* \param src source tensor
 * \param stream the stream; when specified, the copy can exhibit asynchronous behavior
* \tparam dim specify the dim of tensor
* \tparam DType type of element in tensor
*/
template<int dim, typename DType>
inline void Copy(Tensor<cpu, dim, DType> dst,
const Tensor<gpu, dim, DType> &src,
Stream<gpu> *stream = NULL);
/*!
* \brief copy data from one tensor to another, with same shape
* \param dst target tensor
* \param src source tensor
 * \param stream the stream; when specified, the copy can exhibit asynchronous behavior
* \tparam dim specify the dim of tensor
* \tparam DType type of element in tensor
*/
template<int dim, typename DType>
inline void Copy(Tensor<gpu, dim, DType> dst,
const Tensor<cpu, dim, DType> &src,
Stream<gpu> *stream = NULL);
/*!
* \brief copy data from one tensor to another, with same shape
* \param dst target tensor
* \param src source tensor
 * \param stream the stream; when specified, the copy can exhibit asynchronous behavior
* \tparam dim specify the dim of tensor
* \tparam DType type of element in tensor
*/
template<int dim, typename DType>
inline void Copy(Tensor<gpu, dim, DType> dst,
const Tensor<gpu, dim, DType> &src,
Stream<gpu> *stream = NULL);
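/*!
 * \brief Example: copying between two CPU tensors of the same shape
 *        (a minimal sketch, reusing the allocation helpers above):
 * \code
 *   Tensor<cpu, 1, float> src = NewTensor<cpu>(Shape1(16), 1.0f);
 *   Tensor<cpu, 1, float> dst = NewTensor<cpu>(Shape1(16), 0.0f);
 *   Copy(dst, src);  // synchronous on CPU when no stream is given
 *   FreeSpace(&src);
 *   FreeSpace(&dst);
 * \endcode
 */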
/*!
* \brief CPU/GPU: normalize softmax: dst[i][j] = exp(energy[i][j]) /(sum_j exp(energy[i][j]))
* \param dst destination
* \param energy input energy
*/
template<typename DType>
inline void Softmax(Tensor<cpu, 2, DType> dst, const Tensor<cpu, 2, DType> &energy);
/*!
* \brief CPU/GPU: normalize softmax: dst[i][j] = exp(energy[i][j]) /(sum_j exp(energy[i][j]))
* \param dst destination
* \param energy input energy
*/
template<typename DType>
inline void Softmax(Tensor<gpu, 2, DType> dst, const Tensor<gpu, 2, DType> &energy);
/*!
* \brief CPU/GPU: softmax gradient
* \param dst destination
* \param src source output
* \param label label info
*/
template<typename DType>
inline void SoftmaxGrad(Tensor<cpu, 2, DType> dst,
const Tensor<cpu, 2, DType> &src,
const Tensor<cpu, 1, DType> &label);
/*!
* \brief CPU/GPU: softmax gradient
* \param dst destination
* \param src source output
* \param label label info
*/
template<typename DType>
inline void SoftmaxGrad(Tensor<gpu, 2, DType> dst,
const Tensor<gpu, 2, DType> &src,
const Tensor<gpu, 1, DType> &label);
/*!
 * \brief CPU/GPU: Gradient accumulation of the embedding matrix.
dst[index[i]] += src[i]
Called when the featuredim of src is much larger than the batchsize
* \param dst destination
* \param index index to take
* \param src source output
*/
template<typename IndexType, typename DType>
inline void AddTakeGrad(Tensor<cpu, 2, DType> dst,
const Tensor<cpu, 1, IndexType>& index,
const Tensor<cpu, 2, DType> &src);
/*!
 * \brief CPU/GPU: Gradient accumulation of the embedding matrix.
dst[index[i]] += src[i]
Called when the featuredim of src is much larger than the batchsize
* \param dst destination
* \param index index to take
* \param src source output
*/
template<typename IndexType, typename DType>
inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
const Tensor<gpu, 1, IndexType>& index,
const Tensor<gpu, 2, DType> &src);
/*!
 * \brief CPU/GPU: Gradient accumulation of the embedding matrix.
dst[sorted[i]] += src[index[i]]
Called when the batchsize of src is larger than the featuredim
* \param dst destination
* \param sorted the sorted indices
* \param index original index of the sorted indices
* \param src source output
*/
template<typename IndexType, typename DType>
inline void AddTakeGradLargeBatch(Tensor<cpu, 2, DType> dst,
const Tensor<gpu, 1, IndexType>& sorted,
const Tensor<cpu, 1, IndexType>& index,
const Tensor<cpu, 2, DType> &src);
/*!
 * \brief CPU/GPU: Gradient accumulation of the embedding matrix.
dst[sorted[i]] += src[index[i]]
Called when the batchsize of src is larger than the featuredim
* \param dst destination
* \param sorted the sorted indices
* \param index original index of the sorted indices
* \param src source output
*/
template<typename IndexType, typename DType>
inline void AddTakeGradLargeBatch(Tensor<gpu, 2, DType> dst,
const Tensor<gpu, 1, IndexType>& sorted,
const Tensor<gpu, 1, IndexType>& index,
const Tensor<gpu, 2, DType> &src);
/*!
 * \brief CPU/GPU: Fill specific rows of the destination matrix with the rows of the source matrix.
                   dst[index[i]] = src[i]
                   Will use atomicAdd in the inner implementation, so the result may not be deterministic.
 * \param dst destination
 * \param index the row indices in dst to fill
 * \param src source output
*/
template<typename IndexType, typename DType>
inline void IndexFill(Tensor<cpu, 2, DType> dst,
const Tensor<cpu, 1, IndexType>& index,
const Tensor<cpu, 2, DType> &src);
/*!
 * \brief CPU/GPU: Fill specific rows of the destination matrix with the rows of the source matrix.
                   dst[index[i]] = src[i]
                   Will use atomicAdd in the inner implementation, so the result may not be deterministic.
 * \param dst destination
 * \param index the row indices in dst to fill
 * \param src source output
*/
template<typename IndexType, typename DType>
inline void IndexFill(Tensor<gpu, 2, DType> dst,
const Tensor<gpu, 1, IndexType>& index,
const Tensor<gpu, 2, DType> &src);
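/*!
 * Example (illustrative sketch): writing the rows of src into the rows of dst
 * selected by index, assuming using namespace mshadow is in effect. Sizes and
 * indices are assumptions made only for this example.
 * \code
 * const index_t rows = 20, cols = 5, n = 3;
 * std::vector<float> dbuf(rows * cols, 0.0f), sbuf(n * cols, 7.0f);
 * std::vector<int> ibuf = {1, 4, 9};
 * Tensor<cpu, 2, float> dst(dbuf.data(), Shape2(rows, cols));
 * Tensor<cpu, 2, float> src(sbuf.data(), Shape2(n, cols));
 * Tensor<cpu, 1, int> index(ibuf.data(), Shape1(n));
 * IndexFill(dst, index, src);   // rows 1, 4 and 9 of dst are now filled with 7.0f
 * \endcode
 */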
/*!
* \brief CPU/GPU: Sort key-value pairs stored in separate places. (Stable sort is performed!)
* \param keys the keys to sort
 * \param values the values that are sorted along with the keys
* \param is_ascend whether to sort key in ascending order
*/
template<typename KDType, typename VDType>
inline void SortByKey(Tensor<cpu, 1, KDType> keys, Tensor<cpu, 1, VDType> values,
bool is_ascend = true);
/*!
* \brief CPU/GPU: Sort key-value pairs stored in separate places. (Stable sort is performed!)
* \param keys the keys to sort
 * \param values the values that are sorted along with the keys
* \param is_ascend whether to sort key in ascending order
*/
template<typename KDType, typename VDType>
inline void SortByKey(Tensor<gpu, 1, KDType> keys, Tensor<gpu, 1, VDType> values,
bool is_ascend = true);
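/*!
 * Example (illustrative sketch): sorting a key tensor while permuting a value
 * tensor accordingly, assuming using namespace mshadow is in effect. The key
 * and value contents are assumptions made only for this example.
 * \code
 * std::vector<float> kbuf = {3.0f, 1.0f, 2.0f};
 * std::vector<int>   vbuf = {0, 1, 2};
 * Tensor<cpu, 1, float> keys(kbuf.data(), Shape1(3));
 * Tensor<cpu, 1, int>   values(vbuf.data(), Shape1(3));
 * SortByKey(keys, values);          // keys -> {1, 2, 3}, values -> {1, 2, 0}
 * SortByKey(keys, values, false);   // sort again, this time in descending key order
 * \endcode
 */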
/*!
 * \brief CPU/GPU: Sort the values within each segment. (A stable sort is performed!)
                   Segments are described by an ascending-ordered vector such as [0, 0, 0, 1, 1, 2, 3, 3, 3, ...];
                   the values labeled 0, 1, 2, 3, ... are each sorted separately.
                   Currently only sorting in ascending order is supported.
* \param values the data to sort
* \param segments segment indicator
*/
template<typename Device, typename VDType, typename SDType>
inline void VectorizedSort(Tensor<Device, 1, VDType> values, Tensor<Device, 1, SDType> segments);
// function declarations to support expression, no need to understand them
// these functions do not need to be directly used
/*!
 * \brief CPU/GPU: map an expression to a tensor; this function calls MapPlan
 * \tparam Saver specify storage method
 * \tparam R specifies the storage type of the tensor
 * \tparam dim dim of the tensor, during usage, there is no need to specify this parameter
 * \tparam DType the type of elements in the tensor
 * \tparam E specifies the expression type, no need to specify this parameter during usage
 * \tparam etype expression type
 * \param dst destination
 * \param exp expression
 * \sa namespace mshadow::sv, mshadow::op, mshadow::expr
*/
template<typename Saver, typename R, int dim,
typename DType, typename E, int etype>
inline void MapExp(TRValue<R, cpu, dim, DType> *dst,
const expr::Exp<E, DType, etype> &exp);
/*!
 * \brief CPU/GPU: map an expression to a tensor; this function calls MapPlan
 * \tparam Saver specify storage method
 * \tparam R specifies the storage type of the tensor
 * \tparam dim dim of the tensor, during usage, there is no need to specify this parameter
 * \tparam DType the type of elements in the tensor
 * \tparam E specifies the expression type, no need to specify this parameter during usage
 * \tparam etype expression type
 * \param dst destination
 * \param exp expression
 * \sa namespace mshadow::sv, mshadow::op, mshadow::expr
*/
template<typename Saver, typename R, int dim,
typename DType, typename E, int etype>
inline void MapExp(TRValue<R, gpu, dim, DType> *dst,
const expr::Exp<E, DType, etype> &exp);
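/*!
 * Example (illustrative sketch): evaluating element-wise expressions into a
 * tensor. Assignment with operator= dispatches to MapExp internally; an
 * explicit call with a different saver is shown for comparison. Assumes
 * using namespace mshadow is in effect; buffers are assumptions made only for
 * this example.
 * \code
 * std::vector<float> a(6, 1.0f), b(6, 2.0f), c(6, 0.0f);
 * Tensor<cpu, 2, float> lhs(a.data(), Shape2(2, 3));
 * Tensor<cpu, 2, float> rhs(b.data(), Shape2(2, 3));
 * Tensor<cpu, 2, float> dst(c.data(), Shape2(2, 3));
 * dst = lhs + rhs;                       // implicitly MapExp<sv::saveto>(&dst, lhs + rhs)
 * MapExp<sv::plusto>(&dst, lhs * rhs);   // dst += lhs * rhs, element-wise
 * \endcode
 */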
/*!
 * \brief CPU/GPU: map an expression and reduce it to a 1D Tensor in the lowest dimension (dimension 0)
 * \tparam Saver specify storage method
 * \tparam Reducer specify a reducer method
 * \tparam R specifies the storage type of the tensor
 * \tparam DType the type of elements in the tensor
 * \tparam E specifies the expression type, no need to specify this parameter during usage
 * \tparam etype expression type
 * \param dst destination
 * \param exp expression
 * \param scale scale the result before saving
 * \sa namespace mshadow::sv, mshadow::op, mshadow::red, mshadow::expr
*/
template<typename Saver, typename Reducer,
typename R, typename DType, typename E, int etype>
inline void MapReduceKeepLowest(TRValue<R, cpu, 1, DType> *dst,
const expr::Exp<E, DType, etype> &exp,
DType scale = 1);
/*!
 * \brief CPU/GPU: map an expression and reduce it to a 1D Tensor in the lowest dimension (dimension 0)
 * \tparam Saver specify storage method
 * \tparam Reducer specify a reducer method
 * \tparam R specifies the storage type of the tensor
 * \tparam DType the type of elements in the tensor
 * \tparam E specifies the expression type, no need to specify this parameter during usage
 * \tparam etype expression type
 * \param dst destination
 * \param exp expression
 * \param scale scale the result before saving
 * \sa namespace mshadow::sv, mshadow::op, mshadow::red, mshadow::expr
*/
template<typename Saver, typename Reducer, typename R,
typename DType, typename E, int etype>
inline void MapReduceKeepLowest(TRValue<R, gpu, 1, DType> *dst,
const expr::Exp<E, DType, etype> &exp,
DType scale = 1);
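/*!
 * Example (illustrative sketch): summing a 2-D tensor over its leading
 * dimension, so that dst[x] accumulates src[y][x] over all y. Assumes using
 * namespace mshadow is in effect; sizes and values are assumptions made only
 * for this example.
 * \code
 * std::vector<float> sbuf(4 * 3, 1.0f), dbuf(3, 0.0f);
 * Tensor<cpu, 2, float> src(sbuf.data(), Shape2(4, 3));
 * Tensor<cpu, 1, float> dst(dbuf.data(), Shape1(3));
 * MapReduceKeepLowest<sv::saveto, red::sum>(&dst, src, 1.0f);   // dst becomes {4, 4, 4}
 * \endcode
 */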
/*!
 * \brief CPU/GPU: map an expression and reduce it to a 1D Tensor in the third dimension (dimension 2)
 * \tparam Saver specify storage method
 * \tparam Reducer specify a reducer method
 * \tparam R specifies the storage type of the tensor
 * \tparam DType the type of elements in the tensor
 * \tparam dimkeep the target dimension to be kept; must be larger than 0 (for 0, use MapReduceKeepLowest)
 * \tparam E specifies the expression type, no need to specify this parameter during usage
 * \tparam etype expression type
 * \param dst destination
 * \param exp expression
 * \param scale scale the result before saving
 * \sa namespace mshadow::sv, mshadow::op, mshadow::red, mshadow::expr
*/
template<typename Saver, typename Reducer, int dimkeep,
typename R, typename DType, typename E, int etype>
inline void MapReduceKeepHighDim(TRValue<R, cpu, 1, DType> *dst,
const expr::Exp<E, DType, etype> &exp,
DType scale = 1);
/*!
 * \brief CPU/GPU: map an expression and reduce it to a 1D Tensor in the third dimension (dimension 2)
 * \tparam Saver specify storage method
 * \tparam Reducer specify a reducer method
 * \tparam R specifies the storage type of the tensor
 * \tparam DType the type of elements in the tensor
 * \tparam dimkeep the target dimension to be kept; must be larger than 0 (for 0, use MapReduceKeepLowest)
 * \tparam E specifies the expression type, no need to specify this parameter during usage
 * \tparam etype expression type
 * \param dst destination
 * \param exp expression
 * \param scale scale the result before saving
 * \sa namespace mshadow::sv, mshadow::op, mshadow::red, mshadow::expr
*/
template<typename Saver, typename Reducer, int dimkeep,
typename R, typename DType, typename E, int etype>
inline void MapReduceKeepHighDim(TRValue<R, gpu, 1, DType> *dst,
const expr::Exp<E, DType, etype> &exp,
DType scale = 1);
/*!
 * \brief CPU/GPU: 1-dimensional vector dot product
* \param dst Length 1 vector, used to hold the result.
* \param lhs Left operand vector
* \param rhs Right operand vector
*/
template<typename Device, typename DType>
inline void VectorDot(Tensor<Device, 1, DType> dst,
const Tensor<Device, 1, DType> &lhs,
const Tensor<Device, 1, DType> &rhs);
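/*!
 * Example (illustrative sketch): dot product of two 1-D tensors written into a
 * length-1 result tensor, assuming using namespace mshadow is in effect and a
 * BLAS backend is enabled. Values are assumptions made only for this example.
 * \code
 * std::vector<float> a = {1.0f, 2.0f, 3.0f}, b = {4.0f, 5.0f, 6.0f}, r(1);
 * Tensor<cpu, 1, float> lhs(a.data(), Shape1(3));
 * Tensor<cpu, 1, float> rhs(b.data(), Shape1(3));
 * Tensor<cpu, 1, float> dst(r.data(), Shape1(1));
 * VectorDot(dst, lhs, rhs);   // dst[0] = 1*4 + 2*5 + 3*6 = 32
 * \endcode
 */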
/*!
* \brief CPU/GPU: dst = alpha * op(lhs) op(rhs) + beta * dst
 * \param dst 3D tensor (batch of output matrices), used to hold the result
 * \param lhs left operand, a 3D tensor (batch of matrices)
 * \param rhs right operand, a 3D tensor (batch of matrices)
* \param alpha multiplier of op(lhs)op(rhs)
* \param beta multiplier of dst
* \param workspace Workspace for casting DType* to DType** (batched-view), must have size >= 3 * batch_size
*/
template<bool transpose_left, bool transpose_right, typename Device, typename DType>
inline void BatchGEMM(Tensor<Device, 3, DType> dst,
const Tensor<Device, 3, DType> &lhs,
const Tensor<Device, 3, DType> &rhs,
DType alpha,
DType beta,
Tensor<Device, 1, DType*> workspace);
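/*!
 * Example (illustrative sketch): batched matrix multiply dst = 1 * lhs * rhs
 * on the CPU, with the required pointer workspace of size 3 * batch. Assumes
 * using namespace mshadow is in effect; shapes and values are assumptions made
 * only for this example.
 * \code
 * const index_t batch = 2, m = 4, k = 3, n = 5;
 * std::vector<float> lb(batch * m * k, 1.0f), rb(batch * k * n, 1.0f), db(batch * m * n);
 * std::vector<float*> wb(3 * batch);
 * Tensor<cpu, 3, float> lhs(lb.data(), Shape3(batch, m, k));
 * Tensor<cpu, 3, float> rhs(rb.data(), Shape3(batch, k, n));
 * Tensor<cpu, 3, float> dst(db.data(), Shape3(batch, m, n));
 * Tensor<cpu, 1, float*> workspace(wb.data(), Shape1(3 * batch));
 * BatchGEMM<false, false>(dst, lhs, rhs, 1.0f, 0.0f, workspace);   // every entry becomes k = 3
 * \endcode
 */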
} // namespace mshadow
// include headers
//===== EXPANDING: ../mshadow/mshadow/stream_gpu-inl.h =====
/*!
* Copyright (c) 2014 by Contributors
* \file stream_gpu-inl.h
* \brief implementation of GPU code
* \author Bing Xu, Tianqi Chen
*/
#ifndef MSHADOW_STREAM_GPU_INL_H_
#define MSHADOW_STREAM_GPU_INL_H_
namespace mshadow {
#if MSHADOW_USE_CUDA == 1
// Stream allocation
// actual implementation of GPU stream in CUDA
template<>
struct Stream<gpu> {
/*! \brief handle state */
enum HandleState {
NoHandle = 0,
OwnHandle = 1,
};
/*! \brief cudaStream */
cudaStream_t stream_;
/*! \brief cublas handle */
cublasHandle_t blas_handle_;
/*! \brief cudnn handle */
#if MSHADOW_USE_CUDNN == 1
cudnnHandle_t dnn_handle_;
#endif
/*! \brief cublas handle ownership */
HandleState blas_handle_ownership_;
/*! \brief cudnn handle ownership */
HandleState dnn_handle_ownership_;
Stream(void) : stream_(0),
blas_handle_ownership_(NoHandle),
dnn_handle_ownership_(NoHandle) {}
/*!
* \brief wait for all the computation associated
* with this stream to complete
*/
inline void Wait(void) {
MSHADOW_CUDA_CALL(cudaStreamSynchronize(stream_));
}
/*!
   * \brief query whether the stream is idle
   * \return true if the stream is idle and all jobs have been completed
*/
inline bool CheckIdle(void) {
cudaError_t err = cudaStreamQuery(stream_);
if (err == cudaSuccess) return true;
if (err == cudaErrorNotReady) return false;
LOG(FATAL) << cudaGetErrorString(err);
return false;
}
/*!
* \brief returns actual cudaStream_t given an input GPU stream pointer
* \param stream pointer to GPU stream
*/
inline static cudaStream_t GetStream(Stream<gpu> *stream) {
if (stream == NULL) {
#if MSHADOW_FORCE_STREAM
LOG(FATAL) << "Default GPU stream was used when MSHADOW_FORCE_STREAM was on";
#endif
return 0;
} else {
return stream->stream_;
}
}
/*!
* \brief return actual cublasHandle
   * \param stream pointer to GPU stream
*/
inline static cublasHandle_t GetBlasHandle(Stream<gpu> *stream) {
if (stream == NULL) {
return 0;
} else {
CHECK_NE(stream->blas_handle_ownership_, NoHandle)
<< "No handle exist in source stream";
return stream->blas_handle_;
}
}
  /*! \brief Destroy the cublas handle if this stream owns it */
inline void DestoryBlasHandle() {
if (blas_handle_ownership_ == OwnHandle) {
cublasStatus_t err = cublasDestroy(blas_handle_);
blas_handle_ownership_ = NoHandle;
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Destory cublas handle failed";
}
}
  /*! \brief Destroy the original cublas handle and create a new one */
inline void CreateBlasHandle() {
this->DestoryBlasHandle();
cublasStatus_t err = cublasCreate(&blas_handle_);
blas_handle_ownership_ = OwnHandle;
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Create cublas handle failed";
}
// #if MSHADOW_USE_CUDNN && defined(__CUDACC__)
#if MSHADOW_USE_CUDNN == 1
inline static cudnnHandle_t GetDnnHandle(Stream<gpu> *stream) {
if (stream == NULL) {
return 0;
} else {
CHECK_NE(stream->dnn_handle_ownership_, NoHandle) << "No handle exist in source stream";
return stream->dnn_handle_;
}
}
#endif
inline void DestroyDnnHandle() {
// #if MSHADOW_USE_CUDNN && defined(__CUDACC__)
#if MSHADOW_USE_CUDNN == 1
if (dnn_handle_ownership_ == OwnHandle) {
cudnnStatus_t err = cudnnDestroy(dnn_handle_);
CHECK_EQ(err, CUDNN_STATUS_SUCCESS) << cudnnGetErrorString(err);
}
#endif
}
inline void CreateDnnHandle() {
// #if MSHADOW_USE_CUDNN == 1 && defined(__CUDACC__)
#if MSHADOW_USE_CUDNN == 1
this->DestroyDnnHandle();
cudnnStatus_t err = cudnnCreate(&dnn_handle_);
CHECK_EQ(err, CUDNN_STATUS_SUCCESS) << cudnnGetErrorString(err);
err = cudnnSetStream(dnn_handle_, stream_);
CHECK_EQ(err, CUDNN_STATUS_SUCCESS) << cudnnGetErrorString(err);
this->dnn_handle_ownership_ = OwnHandle;
#endif
}
};
template<>
inline Stream<gpu> *NewStream<gpu>(bool create_blas_handle,
bool create_dnn_handle) {
Stream<gpu> *st = new Stream<gpu>();
MSHADOW_CUDA_CALL(cudaStreamCreate(&st->stream_));
if (create_blas_handle) {
st->CreateBlasHandle();
}
if (create_dnn_handle) {
st->CreateDnnHandle();
}
return st;
}
template<>
inline void DeleteStream<gpu>(Stream<gpu> *stream) {
MSHADOW_CUDA_CALL(cudaStreamDestroy(stream->stream_));
stream->DestoryBlasHandle();
stream->DestroyDnnHandle();
delete stream;
}
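/*!
 * Example (illustrative sketch): typical lifecycle of a GPU stream created by
 * the factory above, with a cublas handle but no cudnn handle. Device
 * selection and the submitted work are assumptions made only for this example.
 * \code
 * mshadow::InitTensorEngine<mshadow::gpu>(0);          // select device 0
 * mshadow::Stream<mshadow::gpu> *s = mshadow::NewStream<mshadow::gpu>(true, false);
 * // ... launch work on s, e.g. Copy(dst, src, s) ...
 * s->Wait();                                           // block until the stream drains
 * mshadow::DeleteStream(s);
 * mshadow::ShutdownTensorEngine<mshadow::gpu>();
 * \endcode
 */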
#endif
} // namespace mshadow
#endif // MSHADOW_STREAM_GPU_INL_H_
//===== EXPANDED: ../mshadow/mshadow/stream_gpu-inl.h =====
//===== EXPANDING: ../mshadow/mshadow/extension.h =====
/*!
* Copyright by Contributors
* \file extension.h
* \brief some extension of expressions,
* used to support something beyond elementwise op
* \author Tianqi Chen, Bing Xu
*/
#ifndef MSHADOW_EXTENSION_H_
#define MSHADOW_EXTENSION_H_
//===== EXPANDING: ../mshadow/mshadow/expr_engine-inl.h =====
/*!
* Copyright (c) 2014 by Contributors
* \file expr_engine-inl.h
* \brief definitions of how expressions should be evaluated
* \author Tianqi Chen, Bing Xu
*/
#ifndef MSHADOW_EXPR_ENGINE_INL_H_
#define MSHADOW_EXPR_ENGINE_INL_H_
namespace mshadow {
namespace expr {
/*!
 * \brief a general class that allows extensions that produce tensors of some shape
* \tparam SubType type of subclass
* \tparam SrcExp source expression of the MakeTensorExp, the source of operation
* \tparam dim dimension of the expression
* \tparam DType the type of elements
*/
template<typename SubType, typename SrcExp, int dim, typename DType>
struct MakeTensorExp
: public Exp<MakeTensorExp<SubType, SrcExp, dim, DType>,
DType, type::kChainer> {
/*! \brief the shape of this expression */
Shape<dim> shape_;
/*! \brief true self of subtype */
inline const SubType& real_self(void) const{
return *static_cast<const SubType*>(this);
}
};
//----------------------------------------------------------------------
// This part of code gives plan that can be used to carry out execution
//---------------------------------------------------------------------
// Declarations of plans
template<typename ExpType, typename DType>
class Plan {
public:
/*!
* \brief evaluate the expression at index [y][x]
* to be implemented by SubType, for RValue, the return type will be DType &
*/
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const;
};
// tensor plan
template <typename Device, int dim, typename DType>
class Plan<Tensor<Device, dim, DType>, DType> {
public:
explicit Plan(const Tensor<Device, dim, DType> &t)
: dptr_(t.dptr_), stride_(t.stride_) {}
// for RValue, the return type should be reference
MSHADOW_XINLINE DType &REval(index_t y, index_t x) {
return dptr_[y * stride_ + x];
}
// const evaluation
MSHADOW_XINLINE const DType &Eval(index_t y, index_t x) const {
return dptr_[y * stride_ + x];
}
private:
DType *dptr_;
index_t stride_;
};
// special evaluation case for 1d tensor, no stride
template <typename Device, typename DType>
class Plan<Tensor<Device, 1, DType>, DType> {
public:
explicit Plan(const Tensor<Device, 1, DType> &t) : dptr_(t.dptr_) {}
MSHADOW_XINLINE DType &REval(index_t y, index_t x) {
return dptr_[x];
}
MSHADOW_XINLINE const DType &Eval(index_t y, index_t x) const {
return dptr_[x];
}
private:
DType *dptr_;
};
// scalar
template<typename DType>
class Plan<ScalarExp<DType>, DType> {
public:
explicit Plan(DType scalar) : scalar_(scalar) {}
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
return scalar_;
}
private:
DType scalar_;
};
// unary expression
template<typename DstDType, typename SrcDType,
typename EType, int etype>
class Plan<TypecastExp<DstDType, SrcDType, EType, etype>, DstDType> {
public:
explicit Plan(const Plan<EType, SrcDType> &src) : src_(src) {}
MSHADOW_XINLINE DstDType Eval(index_t y, index_t x) const {
return DstDType(src_.Eval(y, x)); // NOLINT(*)
}
private:
Plan<EType, SrcDType> src_;
};
// ternary expression
template<typename OP, typename TA, typename TB, typename TC, int etype, typename DType>
class Plan<TernaryMapExp<OP, TA, TB, TC, DType, etype>, DType> {
public:
explicit Plan(const Plan<TA, DType> &item1, const Plan<TB, DType> &item2,
const Plan<TC, DType> &item3)
: item1_(item1), item2_(item2), item3_(item3) {}
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
return OP::Map(item1_.Eval(y, x), item2_.Eval(y, x), item3_.Eval(y, x));
}
private:
Plan<TA, DType> item1_;
Plan<TB, DType> item2_;
Plan<TC, DType> item3_;
};
// binary expression
template<typename OP, typename TA, typename TB, int etype, typename DType>
class Plan<BinaryMapExp<OP, TA, TB, DType, etype>, DType> {
public:
explicit Plan(const Plan<TA, DType> &lhs, const Plan<TB, DType> &rhs)
: lhs_(lhs), rhs_(rhs) {}
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
return OP::Map(lhs_.Eval(y, x), rhs_.Eval(y, x));
}
private:
Plan<TA, DType> lhs_;
Plan<TB, DType> rhs_;
};
// unary expression
template<typename OP, typename TA, int etype, typename DType>
class Plan<UnaryMapExp<OP, TA, DType, etype>, DType> {
public:
explicit Plan(const Plan<TA, DType> &src) : src_(src) {}
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
return OP::Map(src_.Eval(y, x));
}
private:
Plan<TA, DType> src_;
};
// remaps MakeTensorExp to the corresponding subtype's plan
template<typename SubType, typename SrcExp, int dim, typename DType>
struct Plan<MakeTensorExp<SubType, SrcExp, dim, DType>, DType> {
public:
Plan(const Plan<SubType, DType> &src) : src_(src) {}
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
return src_.Eval(y, x);
}
private:
Plan<SubType, DType> src_;
};
// transpose
template<typename EType, typename DType>
class Plan<TransposeExp<EType, DType>, DType> {
public:
explicit Plan(const Plan<EType, DType> &src) : src_(src) {}
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
return src_.Eval(x, y);
}
private:
Plan<EType, DType> src_;
};
//----------------------------------------------------------------------
// Mappings from expression to plans
//---------------------------------------------------------------------
template<typename OP, typename TA, typename TB, typename DType, int etype>
inline Plan<BinaryMapExp<OP, TA, TB, DType, etype>, DType>
MakePlan(const BinaryMapExp<OP, TA, TB, DType, etype> &e);
template<typename OP, typename TA, typename TB, typename TC, typename DType, int etype>
inline Plan<TernaryMapExp<OP, TA, TB, TC, DType, etype>, DType>
MakePlan(const TernaryMapExp<OP, TA, TB, TC, DType, etype> &e);
template<typename DType>
inline Plan<ScalarExp<DType>, DType> MakePlan(const ScalarExp<DType> &e) {
return Plan<ScalarExp<DType>, DType>(e.scalar_);
}
template<typename DstDType, typename SrcDType, typename EType, int etype>
inline Plan<TypecastExp<DstDType, SrcDType, EType, etype>, DstDType>
MakePlan(const TypecastExp<DstDType, SrcDType, EType, etype> &e) {
return Plan<TypecastExp<DstDType, SrcDType, EType, etype>, DstDType>(MakePlan(e.exp));
}
template<typename T, typename DType>
inline Plan<T, DType> MakePlan(const RValueExp<T, DType> &e) {
return Plan<T, DType>(e.self());
}
template<typename T, typename DType>
inline Plan<TransposeExp<T, DType>, DType>
MakePlan(const TransposeExp<T, DType> &e) {
return Plan<TransposeExp<T, DType>, DType>(MakePlan(e.exp));
}
template<typename T, typename SrcExp, int dim, typename DType>
inline Plan<T, DType>
MakePlan(const MakeTensorExp<T, SrcExp, dim, DType> &e) {
return Plan<T, DType>(e.real_self());
}
template<typename OP, typename TA, typename DType, int etype>
inline Plan<UnaryMapExp<OP, TA, DType, etype>, DType>
MakePlan(const UnaryMapExp<OP, TA, DType, etype> &e) {
return Plan<UnaryMapExp<OP, TA, DType, etype>, DType>(MakePlan(e.src_));
}
template<typename OP, typename TA, typename TB, typename DType, int etype>
inline Plan<BinaryMapExp<OP, TA, TB, DType, etype>, DType>
MakePlan(const BinaryMapExp<OP, TA, TB, DType, etype> &e) {
return Plan<BinaryMapExp<OP, TA, TB, DType, etype>,
DType>(MakePlan(e.lhs_), MakePlan(e.rhs_));
}
// Ternary
template<typename OP, typename TA, typename TB, typename TC, typename DType, int etype>
inline Plan<TernaryMapExp<OP, TA, TB, TC, DType, etype>, DType>
MakePlan(const TernaryMapExp<OP, TA, TB, TC, DType, etype> &e) {
return Plan<TernaryMapExp<OP, TA, TB, TC, DType, etype>,
DType>(MakePlan(e.item1_), MakePlan(e.item2_), MakePlan(e.item3_));
}
//----------------------------------------------------------------
// Static Type inference and Type Checking
//----------------------------------------------------------------
/*!
 * \brief static type inference template,
 *        used to get the dimension of each expression,
 *        if ExpInfo<E>::kDim == -1, this means there is a mismatch in the expression
* if (ExpInfo<E>::kDevMask & cpu::kDevMask) != 0, this means this expression can be assigned to cpu
* \tparam E expression
*/
template<typename E>
struct ExpInfo {
static const int kDim = -1;
static const int kDevMask = 0;
};
template<typename DType>
struct ExpInfo< ScalarExp<DType> > {
static const int kDim = 0;
static const int kDevMask = 0xffff;
};
template<typename E, typename DType>
struct ExpInfo<TransposeExp<E, DType> > {
static const int kDim = ExpInfo<E>::kDim;
static const int kDevMask = ExpInfo<E>::kDevMask;
};
template<typename DstDType, typename SrcDType, typename EType, int etype>
struct ExpInfo<TypecastExp<DstDType, SrcDType, EType, etype> > {
static const int kDim = ExpInfo<EType>::kDim;
static const int kDevMask = ExpInfo<EType>::kDevMask;
};
template<typename Device, int dim, typename DType>
struct ExpInfo<Tensor<Device, dim, DType> > {
static const int kDim = dim;
static const int kDevMask = Device::kDevMask;
};
template<typename T, typename SrcExp, int dim, typename DType>
struct ExpInfo<MakeTensorExp<T, SrcExp, dim, DType> > {
static const int kDimSrc = ExpInfo<SrcExp>::kDim;
static const int kDim = kDimSrc >= 0 ? dim : -1;
static const int kDevMask = ExpInfo<SrcExp>::kDevMask;
};
template<typename OP, typename TA, typename DType, int etype>
struct ExpInfo<UnaryMapExp<OP, TA, DType, etype> > {
static const int kDim = ExpInfo<TA>::kDim;
static const int kDevMask = ExpInfo<TA>::kDevMask;
};
template<typename OP, typename TA, typename TB, typename DType, int etype>
struct ExpInfo<BinaryMapExp<OP, TA, TB, DType, etype> > {
static const int kDimLhs = ExpInfo<TA>::kDim;
static const int kDimRhs = ExpInfo<TB>::kDim;
static const int kDim = (kDimLhs >= 0 && kDimRhs >= 0) ?\
(kDimLhs == 0 ?\
kDimRhs :\
((kDimRhs == 0 || kDimLhs == kDimRhs) ? kDimLhs : -1)) : -1;
static const int kDevMask = ExpInfo<TA>::kDevMask & ExpInfo<TB>::kDevMask;
};
template<typename OP, typename TA, typename TB, typename TC, typename DType, int etype>
struct ExpInfo<TernaryMapExp<OP, TA, TB, TC, DType, etype> > {
static const int kDimItem1 = ExpInfo<TA>::kDim;
static const int kDimItem2 = ExpInfo<TB>::kDim;
static const int kDimItem3 = ExpInfo<TC>::kDim;
static const int kDim = kDimItem1;
static const int kDevMask = ExpInfo<TA>::kDevMask & ExpInfo<TB>::kDevMask & ExpInfo<TC>::kDevMask;
};
/*! \brief template to do type check */
template<typename Device, int dim, typename DType, typename E>
struct TypeCheck {
/*! \brief dimension of expression*/
static const int kExpDim = ExpInfo<E>::kDim;
/*! \brief whether the expression device type matches */
static const bool kDevPass = (ExpInfo<E>::kDevMask & Device::kDevMask) != 0;
/*! \brief whether the expression can be mapped to expression of dim */
static const bool kMapPass = (kExpDim == 0 || kExpDim == dim) && kDevPass;
/*! \brief whether the expression can be reduced to expression of dim */
static const bool kRedPass = (kExpDim > dim) && kDevPass;
};
/*! \brief used to help static type check*/
template<bool kPass>
struct TypeCheckPass;
// Todo : add static assert using C++11
template<>
struct TypeCheckPass<false> {};
template<>
struct TypeCheckPass<true> {
inline static void Error_All_Tensor_in_Exp_Must_Have_Same_Type(void) {}
inline static void Error_TypeCheck_Not_Pass_For_Reduce_Exp(void) {}
inline static void Error_Expression_Does_Not_Meet_Dimension_Req(void) {}
};
//----------------------------------------------------------------
// Runtime Stream Getting
//----------------------------------------------------------------
template<typename Device, typename E>
struct StreamInfo {
inline static Stream<Device> *Get(const E &t);
};
template<int dim, typename Device, typename DType>
struct StreamInfo<Device, Tensor<Device, dim, DType> > {
inline static Stream<Device> *Get(const Tensor<Device, dim, DType> &t) {
return t.stream_;
}
};
//----------------------------------------------------------------
// Runtime Shape Checking
//----------------------------------------------------------------
/*!
 * \brief runtime shape checking template
 *        get the shape of an expression and report an error if the shapes mismatch
* \tparam dim the dimension of the shape
* \tparam E expression
*/
template<int dim, typename E>
struct ShapeCheck {
inline static Shape<dim> Check(const E &t);
};
template<int dim, typename DType>
struct ShapeCheck<dim, ScalarExp<DType> > {
inline static Shape<dim> Check(const ScalarExp<DType> &exp) {
// use lowest dimension to mark scalar exp
Shape<dim> shape;
for (int i = 0; i < dim; ++i) {
shape[i] = 0;
}
return shape;
}
};
template<int dim, typename DstDType, typename SrcDType, typename EType, int etype>
struct ShapeCheck<dim, TypecastExp<DstDType, SrcDType, EType, etype> > {
inline static Shape<dim>
Check(const TypecastExp<DstDType, SrcDType, EType, etype> &exp) {
return ShapeCheck<dim, EType>::Check(exp.exp);
}
};
template<int dim, typename E, typename DType>
struct ShapeCheck<dim, TransposeExp<E, DType> > {
inline static Shape<dim> Check(const TransposeExp<E, DType> &e) {
// swap the lowest two dimensions
Shape<dim> s = ShapeCheck<dim, E>::Check(e.exp);
std::swap(s[0], s[1]);
return s;
}
};
template<int dim, typename Device, typename DType>
struct ShapeCheck<dim, Tensor<Device, dim, DType> > {
inline static Shape<dim> Check(const Tensor<Device, dim, DType> &t) {
return t.shape_;
}
};
template<int dim, typename SrcExp, typename T, typename DType>
struct ShapeCheck<dim, MakeTensorExp<T, SrcExp, dim, DType> > {
inline static Shape<dim>
Check(const MakeTensorExp<T, SrcExp, dim, DType> &t) {
return t.shape_;
}
};
template<int dim, typename OP, typename TA, typename DType, int etype>
struct ShapeCheck<dim, UnaryMapExp<OP, TA, DType, etype> > {
inline static Shape<dim> Check(const UnaryMapExp<OP, TA, DType, etype> &t) {
Shape<dim> s = ShapeCheck<dim, TA>::Check(t.src_);
return s;
}
};
template<int dim, typename OP, typename TA, typename TB,
typename DType, int etype>
struct ShapeCheck<dim, BinaryMapExp<OP, TA, TB, DType, etype> > {
inline static Shape<dim>
Check(const BinaryMapExp<OP, TA, TB, DType, etype> &t) {
Shape<dim> shape1 = ShapeCheck<dim, TA>::Check(t.lhs_);
Shape<dim> shape2 = ShapeCheck<dim, TB>::Check(t.rhs_);
if (shape1[0] == 0) return shape2;
if (shape2[0] == 0) return shape1;
CHECK_EQ(shape1, shape2) << "BinaryMapExp: Shapes of operands are not the same, " <<
"Shape1=" << shape1 << ", Shape2=" << shape2;
return shape1;
}
};
template<int dim, typename OP, typename TA, typename TB, typename TC,
typename DType, int etype>
struct ShapeCheck<dim, TernaryMapExp<OP, TA, TB, TC, DType, etype> > {
inline static Shape<dim>
Check(const TernaryMapExp<OP, TA, TB, TC, DType, etype> &t) {
Shape<dim> shape1 = ShapeCheck<dim, TA>::Check(t.item1_);
Shape<dim> shape2 = ShapeCheck<dim, TB>::Check(t.item2_);
Shape<dim> shape3 = ShapeCheck<dim, TC>::Check(t.item3_);
bool same = (shape1 == shape2) && (shape2 == shape3);
CHECK(same) << "TernaryMapExp: Shapes of operands are not the same, " <<
"Shape1=" << shape1 << ", Shape2=" << shape2 << ", Shape3=" << shape3;
return shape1;
}
};
} // namespace expr
} // namespace mshadow
// include definition of dot engine
//===== EXPANDING: ../mshadow/mshadow/dot_engine-inl.h =====
/*!
* Copyright (c) 2014 by Contributors
* \file dot_engine-inl.h
* \brief definitions of how Matrix Multiplications can be evaluated
* \author Tianqi Chen
*/
#ifndef MSHADOW_DOT_ENGINE_INL_H_
#define MSHADOW_DOT_ENGINE_INL_H_
//===== EXPANDING: ../mshadow/mshadow/extension/implicit_gemm.h =====
/*!
* Copyright (c) 2014 by Contributors
* \file implicit_gemm.h
* \brief support for implicit GEMM operation
* \author Tianqi Chen
*/
#ifndef MSHADOW_EXTENSION_IMPLICIT_GEMM_H_
#define MSHADOW_EXTENSION_IMPLICIT_GEMM_H_
//===== EXPANDING: ../mshadow/mshadow/packet-inl.h =====
/*!
* Copyright (c) 2014 by Contributors
* \file packet-inl.h
* \brief Generic packet vectorization code
*/
#ifndef MSHADOW_PACKET_INL_H_
#define MSHADOW_PACKET_INL_H_
#ifdef __APPLE__
#else
#endif
namespace mshadow {
/*! \brief namespace of packet math*/
namespace packet {
enum PacketArch {
kPlain,
kSSE2,
};
#if MSHADOW_USE_SSE
#define MSHADOW_DEFAULT_PACKET ::mshadow::packet::kSSE2
#else
#define MSHADOW_DEFAULT_PACKET ::mshadow::packet::kPlain
#endif
// whether packet operator is enabled.
/*!
* \brief Generic packet type
* \tparam DType The data type of the packet.
* \tparam Arch the Arch of the packet.
*/
template<typename DType, PacketArch Arch = MSHADOW_DEFAULT_PACKET>
struct Packet;
template<PacketArch Arch>
struct AlignBytes {
static const index_t value = 4;
};
} // namespace packet
} // namespace mshadow
namespace mshadow {
namespace packet {
/*!
 * \brief analog to cudaMallocPitch, allocate an aligned space with num_line * lspace cells
 * \param out_pitch output parameter, the actual space allocated for each line
* \param lspace number of cells required for each line
* \param num_line number of lines to be allocated
*/
inline void* AlignedMallocPitch(size_t *out_pitch,
size_t lspace,
size_t num_line) {
const index_t bits = AlignBytes<MSHADOW_DEFAULT_PACKET>::value;
const index_t mask = (1 << bits) - 1;
size_t pitch = ((lspace + mask) >> bits) << bits;
*out_pitch = pitch;
#ifdef _MSC_VER
void *res = _aligned_malloc(pitch * num_line, 1 << bits);
#else
void *res;
int ret = posix_memalign(&res, 1 << bits, pitch * num_line);
CHECK_EQ(ret, 0) << "AlignedMallocPitch failed";
#endif
if (res == NULL) {
LOG(FATAL) << "AlignedMallocPitch failed";
}
return res;
}
/*!
* \brief free aligned space
* \param ptr pointer to space to be freed
*/
inline void AlignedFree(void *ptr) {
#ifdef _MSC_VER
_aligned_free(ptr);
#else
free(ptr);
#endif
}
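/*!
 * Example (illustrative sketch): allocating a pitched, aligned 2-D buffer with
 * the helpers above and releasing it again. The line geometry is an assumption
 * made only for this example.
 * \code
 * size_t pitch;                                        // bytes actually reserved per line
 * const size_t num_line = 8, line_bytes = 10 * sizeof(float);
 * void *buf = mshadow::packet::AlignedMallocPitch(&pitch, line_bytes, num_line);
 * float *line1 = reinterpret_cast<float*>(static_cast<char*>(buf) + pitch);   // start of line 1
 * line1[0] = 1.0f;
 * mshadow::packet::AlignedFree(buf);
 * \endcode
 */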
/*! \brief check if a pitch (in bytes) is aligned */
template<PacketArch Arch>
inline bool CheckAlign(size_t pitch) {
const index_t bits = AlignBytes<Arch>::value;
return !(pitch & ((1 << bits) - 1));
}
/*! \brief check if a pointer is aligned */
template<PacketArch Arch>
inline bool CheckAlign(void *ptr) {
return CheckAlign<Arch>(reinterpret_cast<size_t>(ptr));
}
/*!
* \brief get upper bound of aligned index of size
* \param size size of the array
 * \tparam DType the data type of the array elements
*/
template<typename DType, PacketArch Arch>
inline index_t UpperAlign(index_t size) {
const index_t bits = AlignBytes<MSHADOW_DEFAULT_PACKET>::value;
const index_t mask = (1 << bits) - 1;
const index_t fsize = sizeof(DType);
return (((size * fsize + mask) >> bits) << bits) / fsize;
}
/*!
* \brief get lower bound of aligned index of size
* \param size size of the array
 * \tparam DType the data type of the array elements
*/
template<typename DType, PacketArch Arch>
inline index_t LowerAlign(index_t size) {
const index_t bits = AlignBytes<MSHADOW_DEFAULT_PACKET>::value;
const index_t fsize = sizeof(DType);
return (((size * fsize) >> bits) << bits) / fsize;
}
/*!
* \brief generic Packet operator
* \tparam OP The operator
* \tparam DType The data type
* \tparam Arch The architecture.
*/
template<typename OP, typename DType, PacketArch Arch>
struct PacketOp {
static const bool kEnabled = false;
};
// specialization of operators
template<typename DType, PacketArch Arch>
struct PacketOp<op::plus, DType, Arch> {
static const bool kEnabled = true;
MSHADOW_CINLINE static Packet<DType, Arch> Map(const Packet<DType, Arch>& lhs,
const Packet<DType, Arch>& rhs) {
return lhs + rhs;
}
};
template<typename DType, PacketArch Arch>
struct PacketOp<op::minus, DType, Arch> {
static const bool kEnabled = true;
MSHADOW_CINLINE static Packet<DType, Arch> Map(const Packet<DType, Arch>& lhs,
const Packet<DType, Arch>& rhs) {
return lhs - rhs;
}
};
template<typename DType, PacketArch Arch>
struct PacketOp<op::mul, DType, Arch> {
static const bool kEnabled = true;
MSHADOW_CINLINE static Packet<DType, Arch> Map(const Packet<DType, Arch>& lhs,
const Packet<DType, Arch>& rhs) {
return lhs * rhs;
}
};
template<typename DType, PacketArch Arch>
struct PacketOp<op::div, DType, Arch> {
static const bool kEnabled = true;
MSHADOW_CINLINE static Packet<DType, Arch> Map(const Packet<DType, Arch>& lhs,
const Packet<DType, Arch>& rhs) {
return lhs / rhs;
}
};
template<typename DType, PacketArch Arch>
struct PacketOp<op::identity, DType, Arch> {
static const bool kEnabled = true;
MSHADOW_CINLINE static Packet<DType, Arch> Map(const Packet<DType, Arch>& src) {
return src;
}
};
// savers to do storage
template<typename SV, typename TFloat, PacketArch Arch>
struct Saver{
MSHADOW_CINLINE static void Save(TFloat *dst, const Packet<TFloat, Arch>& src) {
Packet<TFloat, Arch> lhs = Packet<TFloat, Arch>::Load(dst);
Packet<TFloat, Arch> ans = PacketOp<typename SV::OPType, TFloat, Arch>::Map(lhs, src);
ans.Store(dst);
}
};
template<typename TFloat, PacketArch Arch>
struct Saver<sv::saveto, TFloat, Arch> {
MSHADOW_CINLINE static void Save(TFloat *dst, const Packet<TFloat, Arch>& src) {
src.Store(dst);
}
};
} // namespace packet
} // namespace mshadow
//===== EXPANDING: ../mshadow/mshadow/packet/plain-inl.h =====
/*!
* Copyright (c) 2014 by Contributors
* \file plain-inl.h
 * \brief support of plain packet that uses the plain datatype.
*/
#ifndef MSHADOW_PACKET_PLAIN_INL_H_
#define MSHADOW_PACKET_PLAIN_INL_H_
namespace mshadow {
namespace packet {
template<typename DType>
struct Packet<DType, kPlain> {
public:
  /*! \brief number of elements in the vector */
static const index_t kSize = 1;
/*! \brief The internal data */
DType data_;
// enable default copy constructor
Packet(void) {}
// constructor from the intrinsic type
explicit Packet(DType data) : data_(data) {}
  // create a packet filled with the target value s
MSHADOW_CINLINE static Packet<DType, kPlain> Fill(DType s) {
return Packet<DType, kPlain>(s);
}
// load from address
MSHADOW_CINLINE static Packet<DType, kPlain> Load(const DType* src) {
return Packet<DType, kPlain>(*src);
}
// load from address
MSHADOW_CINLINE static Packet<DType, kPlain> LoadUnAligned(const DType* src) {
return Packet<DType, kPlain>(*src);
}
// fill it with value s
MSHADOW_CINLINE Packet<DType, kPlain>& operator=(DType s) {
data_ = s;
return *this;
}
// store data into dst
MSHADOW_CINLINE void Store(DType* dst) const {
*dst = data_;
}
// get the sum of all contents
MSHADOW_CINLINE DType Sum() const {
return data_;
}
};
template<typename DType>
MSHADOW_CINLINE Packet<DType, kPlain> operator+(const Packet<DType, kPlain>& lhs,
const Packet<DType, kPlain>& rhs) {
return Packet<DType, kPlain>(lhs.data_ + rhs.data_);
}
template<typename DType>
MSHADOW_CINLINE Packet<DType, kPlain> operator-(const Packet<DType, kPlain>& lhs,
const Packet<DType, kPlain>& rhs) {
return Packet<DType, kPlain>(lhs.data_ - rhs.data_);
}
template<typename DType>
MSHADOW_CINLINE Packet<DType, kPlain> operator*(const Packet<DType, kPlain>& lhs,
const Packet<DType, kPlain>& rhs) {
return Packet<DType, kPlain>(lhs.data_ * rhs.data_);
}
template<typename DType>
MSHADOW_CINLINE Packet<DType, kPlain> operator/(const Packet<DType, kPlain>& lhs,
const Packet<DType, kPlain>& rhs) {
return Packet<DType, kPlain>(lhs.data_ / rhs.data_);
}
} // namespace packet
} // namespace mshadow
#endif // MSHADOW_PACKET_PLAIN_INL_H_
//===== EXPANDED: ../mshadow/mshadow/packet/plain-inl.h =====
#if MSHADOW_USE_SSE && !defined(__CUDACC__)
//===== EXPANDING: ../mshadow/mshadow/packet/sse-inl.h =====
/*!
* Copyright (c) 2014 by Contributors
* \file sse-inl.h
* \brief support of sse2 packet optimization of some operations
* \author Tianqi Chen
*/
#ifndef MSHADOW_PACKET_SSE_INL_H_
#define MSHADOW_PACKET_SSE_INL_H_
namespace mshadow {
namespace packet {
template<>
struct Packet<float, kSSE2> {
public:
  /*! \brief number of floats in the vector */
static const index_t kSize = 4;
/*! \brief The internal data */
__m128 data_;
// enable default copy constructor
Packet(void) {}
// constructor from the intrinsic type
explicit Packet(__m128 data) : data_(data) {}
  // create a packet filled with the target value s
MSHADOW_CINLINE static Packet<float, kSSE2> Fill(float s) {
return Packet<float, kSSE2>(_mm_set1_ps(s));
}
// load from address
MSHADOW_CINLINE static Packet<float, kSSE2> Load(const float* src) {
return Packet<float, kSSE2>(_mm_load_ps(src));
}
// load from address
MSHADOW_CINLINE static Packet<float, kSSE2> LoadUnAligned(const float* src) {
return Packet<float, kSSE2>(_mm_loadu_ps(src));
}
// fill it with value s
MSHADOW_CINLINE Packet<float, kSSE2>& operator=(float s) {
data_ = _mm_set1_ps(s);
return *this;
}
// store data into dst
MSHADOW_CINLINE void Store(float* dst) const {
_mm_store_ps(dst, data_);
}
// get the sum of all contents
MSHADOW_CINLINE float Sum() const {
__m128 ans = _mm_add_ps(data_, _mm_movehl_ps(data_, data_));
__m128 rst = _mm_add_ss(ans, _mm_shuffle_ps(ans, ans, 1));
#if defined(_MSC_VER) && (_MSC_VER <= 1500) && defined(_WIN64)
return rst.m128_f32[0];
#else
float rr = _mm_cvtss_f32(rst);
return rr;
#endif
}
};
/*! \brief vector real type for double */
template<>
struct Packet<double, kSSE2> {
  /*! \brief number of doubles in the vector */
static const index_t kSize = 2;
// internal data
__m128d data_;
// constructor
Packet(void) {}
explicit Packet(__m128d data) : data_(data) {}
  // create a packet filled with the target value s
MSHADOW_CINLINE static Packet<double, kSSE2> Fill(double s) {
return Packet<double, kSSE2>(_mm_set1_pd(s));
}
// load from address
MSHADOW_CINLINE static Packet<double, kSSE2> Load(const double* src) {
return Packet<double, kSSE2>(_mm_load_pd(src));
}
MSHADOW_CINLINE static Packet<double, kSSE2> LoadUnAligned(const double* src) {
return Packet<double, kSSE2>(_mm_loadu_pd(src));
}
// fill it with value s
MSHADOW_CINLINE Packet<double, kSSE2>& operator=(double s) {
data_ = _mm_set1_pd(s);
return *this;
}
// store data into dst
MSHADOW_CINLINE void Store(double* dst) const {
_mm_store_pd(dst, data_);
}
// get sum of all content
inline double Sum(void) const {
__m128d tmp = _mm_add_sd(data_, _mm_unpackhi_pd(data_, data_));
#if defined(_MSC_VER) && (_MSC_VER <= 1500) && defined(_WIN64)
return tmp.m128d_f64[0];
#else
double ans = _mm_cvtsd_f64(tmp);
return ans;
#endif
}
};
MSHADOW_CINLINE Packet<float, kSSE2> operator+(const Packet<float, kSSE2>& lhs,
const Packet<float, kSSE2>& rhs) {
return Packet<float, kSSE2>(_mm_add_ps(lhs.data_, rhs.data_));
}
MSHADOW_CINLINE Packet<double, kSSE2> operator+(const Packet<double, kSSE2>& lhs,
const Packet<double, kSSE2>& rhs) {
return Packet<double, kSSE2>(_mm_add_pd(lhs.data_, rhs.data_));
}
MSHADOW_CINLINE Packet<float, kSSE2> operator-(const Packet<float, kSSE2>& lhs,
const Packet<float, kSSE2>& rhs) {
return Packet<float, kSSE2>(_mm_sub_ps(lhs.data_, rhs.data_));
}
MSHADOW_CINLINE Packet<double, kSSE2> operator-(const Packet<double, kSSE2>& lhs,
const Packet<double, kSSE2>& rhs) {
return Packet<double, kSSE2>(_mm_sub_pd(lhs.data_, rhs.data_));
}
MSHADOW_CINLINE Packet<float, kSSE2> operator*(const Packet<float, kSSE2>& lhs,
const Packet<float, kSSE2>& rhs) {
return Packet<float, kSSE2>(_mm_mul_ps(lhs.data_, rhs.data_));
}
MSHADOW_CINLINE Packet<double, kSSE2> operator*(const Packet<double, kSSE2>& lhs,
const Packet<double, kSSE2>& rhs) {
return Packet<double, kSSE2>(_mm_mul_pd(lhs.data_, rhs.data_));
}
MSHADOW_CINLINE Packet<float, kSSE2> operator/(const Packet<float, kSSE2>& lhs,
const Packet<float, kSSE2>& rhs) {
return Packet<float, kSSE2>(_mm_div_ps(lhs.data_, rhs.data_));
}
MSHADOW_CINLINE Packet<double, kSSE2> operator/(const Packet<double, kSSE2>& lhs,
const Packet<double, kSSE2>& rhs) {
return Packet<double, kSSE2>(_mm_div_pd(lhs.data_, rhs.data_));
}
} // namespace packet
} // namespace mshadow
#endif // MSHADOW_PACKET_SSE_INL_H_
//===== EXPANDED: ../mshadow/mshadow/packet/sse-inl.h =====
#endif
namespace mshadow {
namespace expr {
typedef packet::PacketArch PacketArch;
// same as Plan, but uses packets
template<typename ExpType, typename DType, PacketArch Arch>
class PacketPlan {
public:
/*!
* \brief evaluate the expression at index [y][x],
* x will be aligned to Packet<DType, Arch>::kSize
*/
MSHADOW_CINLINE packet::Packet<DType, Arch> EvalPacket(index_t y, index_t x) const;
MSHADOW_CINLINE DType Eval(index_t y, index_t x) const;
};
template <typename Device, int dim, typename DType, PacketArch Arch>
class PacketPlan<Tensor<Device, dim, DType>, DType, Arch> {
public:
explicit PacketPlan(const Tensor<Device, dim, DType> &t)
:dptr_(t.dptr_), stride_(t.stride_) {}
MSHADOW_CINLINE packet::Packet<DType, Arch> EvalPacket(index_t y, index_t x) const {
return packet::Packet<DType, Arch>::Load(&dptr_[y * stride_ + x]);
}
MSHADOW_CINLINE DType Eval(index_t y, index_t x) const {
return dptr_[y * stride_ + x];
}
private:
const DType *dptr_;
index_t stride_;
};
template<typename DType, PacketArch Arch>
class PacketPlan<ScalarExp<DType>, DType, Arch> {
public:
explicit PacketPlan(DType scalar) : scalar_(scalar) {}
MSHADOW_CINLINE packet::Packet<DType, Arch> EvalPacket(index_t y, index_t x) const {
return packet::Packet<DType, Arch>::Fill(scalar_);
}
MSHADOW_CINLINE DType Eval(index_t y, index_t x) const {
return scalar_;
}
private:
DType scalar_;
};
template<typename OP, typename TA, typename TB, int etype, typename DType, PacketArch Arch>
class PacketPlan<BinaryMapExp<OP, TA, TB, DType, etype>, DType, Arch> {
public:
PacketPlan(const PacketPlan<TA, DType, Arch> &lhs, const PacketPlan<TB, DType, Arch> &rhs)
: lhs_(lhs), rhs_(rhs) {}
MSHADOW_CINLINE packet::Packet<DType, Arch> EvalPacket(index_t y, index_t x) const {
return packet::PacketOp<OP, DType, Arch>::Map(lhs_.EvalPacket(y, x), rhs_.EvalPacket(y, x));
}
MSHADOW_CINLINE DType Eval(index_t y, index_t x) const {
return OP::Map(lhs_.Eval(y, x), rhs_.Eval(y, x));
}
private:
PacketPlan<TA, DType, Arch> lhs_;
PacketPlan<TB, DType, Arch> rhs_;
};
template<typename OP, typename TA, int etype, typename DType, PacketArch Arch>
class PacketPlan<UnaryMapExp<OP, TA, DType, etype>, DType, Arch> {
public:
PacketPlan(const PacketPlan<TA, DType, Arch> &src) : src_(src) {}
MSHADOW_CINLINE packet::Packet<DType> EvalPacket(index_t y, index_t x) const {
return packet::PacketOp<OP, DType, Arch>::Map(src_.EvalPacket(y, x));
}
MSHADOW_CINLINE DType Eval(index_t y, index_t x) const {
return OP::Map(src_.Eval(y, x));
}
private:
PacketPlan<TA, DType, Arch> src_;
};
template<PacketArch Arch, typename OP, typename TA, typename TB, typename DType, int etype>
inline PacketPlan<BinaryMapExp<OP, TA, TB, DType, etype>, DType, Arch>
MakePacketPlan(const BinaryMapExp<OP, TA, TB, DType, etype> &e);
template<PacketArch Arch, typename DType>
inline PacketPlan<ScalarExp<DType>, DType, Arch> MakePacketPlan(const ScalarExp<DType> &e) {
return PacketPlan<ScalarExp<DType>, DType, Arch>(e.scalar_);
}
template<PacketArch Arch, typename T, typename DType>
inline PacketPlan<T, DType, Arch> MakePacketPlan(const RValueExp<T, DType> &e) {
return PacketPlan<T, DType, Arch>(e.self());
}
template<PacketArch Arch, typename T, int dim, typename DType>
inline PacketPlan<T, DType, Arch>
MakePacketPlan(const MakeTensorExp<T, cpu, dim, DType> &e) {
return PacketPlan<T, DType, Arch>(e.real_self());
}
template<PacketArch Arch, typename OP, typename TA, typename DType, int etype>
inline PacketPlan<UnaryMapExp<OP, TA, DType, etype>, DType, Arch>
MakePacketPlan(const UnaryMapExp<OP, TA, DType, etype> &e) {
return PacketPlan<UnaryMapExp<OP, TA, DType, etype>, DType, Arch>(MakePacketPlan<Arch>(e.src_));
}
template<PacketArch Arch, typename OP, typename TA, typename TB, typename DType, int etype>
inline PacketPlan<BinaryMapExp<OP, TA, TB, DType, etype>, DType, Arch>
MakePacketPlan(const BinaryMapExp<OP, TA, TB, DType, etype> &e) {
return PacketPlan<BinaryMapExp<OP, TA, TB, DType, etype>,
DType, Arch>(MakePacketPlan<Arch>(e.lhs_), MakePacketPlan<Arch>(e.rhs_));
}
/*!
 * \brief static check of whether packet evaluation is enabled for an expression
 *
 * \tparam E expression
 * \tparam Arch the packet architecture
*/
template<typename E, PacketArch Arch>
struct PacketCheck{
static const bool kPass = false;
};
template<PacketArch Arch>
struct PacketCheck<float, Arch> {
static const bool kPass = true;
};
template<PacketArch Arch>
struct PacketCheck<double, Arch> {
static const bool kPass = true;
};
template<typename DType, PacketArch Arch>
struct PacketCheck<ScalarExp<DType>, Arch> {
static const bool kPass = PacketCheck<DType, Arch>::kPass;
};
template<int dim, typename DType, PacketArch Arch>
struct PacketCheck<Tensor<cpu, dim, DType>, Arch> {
static const bool kPass = PacketCheck<DType, Arch>::kPass;
};
template<typename OP, typename TA, typename DType, int etype, PacketArch Arch>
struct PacketCheck<UnaryMapExp<OP, TA, DType, etype>, Arch> {
static const bool kPass = PacketCheck<TA, Arch>::kPass &&
packet::PacketOp<OP, DType, Arch>::kEnabled;
};
template<typename OP, typename TA, typename TB, typename DType, int etype, PacketArch Arch>
struct PacketCheck< BinaryMapExp<OP, TA, TB, DType, etype>, Arch> {
static const bool kPass = packet::PacketOp<OP, DType, Arch>::kEnabled &&
PacketCheck<TA, Arch>::kPass && PacketCheck<TB, Arch>::kPass;
};
//----------------------------------------------------
// Check if data is aligned and allow packet operation
//----------------------------------------------------
template<int dim, typename E, PacketArch Arch>
struct PacketAlignCheck {
inline static bool Check(const E &exp) {
return false;
}
};
template<int dim, typename DType, PacketArch Arch>
struct PacketAlignCheck<dim, ScalarExp<DType>, Arch> {
inline static bool Check(const ScalarExp<DType> &exp) {
return true;
}
};
template<int dim, typename DType, PacketArch Arch>
struct PacketAlignCheck<dim, Tensor<cpu, dim, DType>, Arch> {
inline static bool Check(const Tensor<cpu, dim, DType> &t) {
return packet::CheckAlign<Arch>(t.dptr_) &&
packet::CheckAlign<Arch>(t.stride_ * sizeof(DType));
}
};
template<int dim, typename OP, typename TA, typename DType, int etype, PacketArch Arch>
struct PacketAlignCheck<dim, UnaryMapExp<OP, TA, DType, etype>, Arch> {
inline static bool Check(const UnaryMapExp<OP, TA, DType, etype> &t) {
return PacketAlignCheck<dim, TA, Arch>::Check(t.src_);
}
};
template<int dim, typename OP, typename TA, typename TB,
typename DType, int etype, PacketArch Arch>
struct PacketAlignCheck<dim, BinaryMapExp<OP, TA, TB, DType, etype>, Arch> {
inline static bool Check(const BinaryMapExp<OP, TA, TB, DType, etype> &t) {
return PacketAlignCheck<dim, TA, Arch>::Check(t.lhs_) &&
PacketAlignCheck<dim, TB, Arch>::Check(t.rhs_);
}
};
/*!
* \brief use PacketPlan to compute result
*/
template<typename SV, typename E, int dim, typename DType, PacketArch Arch>
inline void MapPacketPlan(Tensor<cpu, dim, DType> _dst,
const expr::PacketPlan<E, DType, Arch>& plan) {
Tensor<cpu, 2, DType> dst = _dst.FlatTo2D();
const index_t xlen = packet::LowerAlign<DType, Arch>(dst.size(1));
#if (MSHADOW_USE_CUDA == 0)
#pragma omp parallel for
#endif
for (openmp_index_t y = 0; y < dst.size(0); ++y) {
for (index_t x = 0; x < xlen; x += packet::Packet<DType, Arch>::kSize) {
packet::Saver<SV, DType, Arch>::Save(&dst[y][x], plan.EvalPacket(y, x));
}
for (index_t x = xlen; x < dst.size(1); ++x) {
SV::Save(dst[y][x], plan.Eval(y, x));
}
}
}
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_PACKET_INL_H_
//===== EXPANDED: ../mshadow/mshadow/packet-inl.h =====
namespace mshadow {
namespace expr {
/*!
* \brief Matrix multiplication.
* \tparam LhsExp type of lhs expression
 * \tparam RhsExp type of rhs expression
* \tparam DType the type of elements
*/
template<typename LhsExp, typename RhsExp, typename DType>
struct ImplicitGEMMExp:
public Exp<ImplicitGEMMExp<LhsExp, RhsExp, DType>,
DType, type::kChainer> {
/*! \brief lhs operand */
const LhsExp &lhs_;
/*! \brief rhs operand */
const RhsExp &rhs_;
  /*! \brief size of the inner (reduction) dimension */
index_t prod_size_;
/*! \brief the shape of this expression */
Shape<2> shape_;
/*! \brief constructor */
ImplicitGEMMExp(const LhsExp &lhs, const RhsExp &rhs)
: lhs_(lhs), rhs_(rhs) {
Shape<2> slhs = ShapeCheck<2, LhsExp>::Check(lhs_);
Shape<2> srhs = ShapeCheck<2, RhsExp>::Check(rhs_);
this->shape_ = mshadow::Shape2(slhs[0], srhs[1]);
prod_size_ = slhs[1];
}
};
template<typename LhsExp, typename RhsExp, typename DType, int e1, int e2>
inline ImplicitGEMMExp<LhsExp, RhsExp, DType>
implicit_dot(const Exp<LhsExp, DType, e1> &lhs,
const Exp<RhsExp, DType, e2> &rhs) {
TypeCheckPass<ExpInfo<LhsExp>::kDim == 2 && ExpInfo<RhsExp>::kDim == 2>
::Error_Expression_Does_Not_Meet_Dimension_Req();
return ImplicitGEMMExp<LhsExp, RhsExp, DType>(lhs.self(), rhs.self());
}
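/*!
 * Example (illustrative sketch): evaluating a matrix product lazily with
 * implicit_dot instead of calling a BLAS gemm, assuming using namespace
 * mshadow is in effect. Shapes and values are assumptions made only for this
 * example.
 * \code
 * std::vector<float> a(2 * 3, 1.0f), b(3 * 4, 1.0f), c(2 * 4);
 * Tensor<cpu, 2, float> lhs(a.data(), Shape2(2, 3));
 * Tensor<cpu, 2, float> rhs(b.data(), Shape2(3, 4));
 * Tensor<cpu, 2, float> dst(c.data(), Shape2(2, 4));
 * dst = expr::implicit_dot(lhs, rhs);   // every entry becomes 3
 * \endcode
 */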
//----------------------
// Execution plan
//----------------------
template<typename LhsExp, typename RhsExp, typename DType>
struct Plan<ImplicitGEMMExp<LhsExp, RhsExp, DType>, DType> {
public:
explicit Plan(const ImplicitGEMMExp<LhsExp, RhsExp, DType> &e)
: lhs_(MakePlan(e.lhs_)),
rhs_(MakePlan(e.rhs_)),
prod_size_(e.prod_size_),
prod_size_lower_align_(packet::LowerAlign<DType, MSHADOW_DEFAULT_PACKET>(e.prod_size_)) {
}
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
typedef packet::Packet<DType> Packet;
Packet sum = Packet::Fill(0);
DType lhs_temp[Packet::kSize], rhs_temp[Packet::kSize];
for (index_t i = 0; i < prod_size_lower_align_; i += packet::Packet<DType>::kSize) {
// unroll
for (index_t j = 0; j < Packet::kSize; ++j) {
lhs_temp[j] = lhs_.Eval(y, i + j);
}
for (index_t j = 0; j < Packet::kSize; ++j) {
rhs_temp[j] = rhs_.Eval(i + j, x);
}
sum = sum + Packet::LoadUnAligned(lhs_temp) * Packet::LoadUnAligned(rhs_temp);
}
DType ret_result = sum.Sum();
for (index_t i = prod_size_lower_align_; i < prod_size_; ++i) {
ret_result += lhs_.Eval(y, i) * rhs_.Eval(i, x);
}
return ret_result;
}
private:
expr::Plan<LhsExp, DType> lhs_;
expr::Plan<RhsExp, DType> rhs_;
const index_t prod_size_;
const index_t prod_size_lower_align_;
};
template<typename LhsExp, typename RhsExp, typename DType>
inline Plan<ImplicitGEMMExp<LhsExp, RhsExp, DType>, DType>
MakePlan(const ImplicitGEMMExp<LhsExp, RhsExp, DType> &exp) {
return Plan<ImplicitGEMMExp<LhsExp, RhsExp, DType>, DType>(exp);
}
template<int dim, typename LhsExp, typename RhsExp, typename DType>
struct ShapeCheck<dim, ImplicitGEMMExp<LhsExp, RhsExp, DType> > {
inline static Shape<dim>
Check(const ImplicitGEMMExp<LhsExp, RhsExp, DType> &t) {
CHECK(dim == 2)
<< "ImplicitGEMMExp only support 2 dimension";
Shape<dim> shape1 = ShapeCheck<dim, LhsExp>::Check(t.lhs_);
Shape<dim> shape2 = ShapeCheck<dim, RhsExp>::Check(t.rhs_);
CHECK_EQ(shape1[1], shape2[0])
<< "implicit_dot The matrix shape do not match";
return t.shape_;
}
};
template<typename LhsExp, typename RhsExp, typename DType>
struct ExpInfo<ImplicitGEMMExp<LhsExp, RhsExp, DType> > {
static const int kDim = 2;
static const int kDevMask = ExpInfo<LhsExp>::kDevMask & ExpInfo<RhsExp>::kDevMask;
};
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXTENSION_IMPLICIT_GEMM_H_
//===== EXPANDED: ../mshadow/mshadow/extension/implicit_gemm.h =====
#ifdef __CUDACC__
#endif // #ifdef __CUDACC__
namespace mshadow {
/*!
* \brief CPU/GPU: Get a batched view of the src array. dst[i] = src + i * stride
* \param dst 2D pointer
* \param src 1D pointer
* \param num number of batches
* \param stride size of each batch
 * \param stream the stream
*/
template<typename Device, typename DType>
inline void GetBatchedView(DType **dst, DType *src, int num, int stride,
Stream<Device> *stream);
template<typename DType>
inline void GetBatchedView(DType **dst, DType *src, int num, int stride,
Stream<cpu> *stream) {
for (int i = 0; i < num; i++) {
dst[i] = src + i * stride;
}
}
#ifdef __CUDACC__
template<typename DType>
inline void GetBatchedView(DType **dst, DType *src, int num, int stride,
Stream<gpu> *stream) {
cuda::GetBatchedView(dst, src, num, stride, stream);
}
#endif // #ifdef __CUDACC__
namespace expr {
//---------------------------------------------------------------------
// Matrix Multiplications, depends on BLAS Engine
//---------------------------------------------------------------------
template<typename SV, typename Device, int ddim, int ldim,
int rdim, bool ltrans, bool rtrans, typename DType>
struct DotEngine {
inline static void Eval(Tensor<Device, ddim, DType> *p_dst,
const Tensor<Device, ldim, DType> &lhs,
const Tensor<Device, rdim, DType> &rhs,
DType scale);
};
// handles the dot, use CblasColMajor
template<typename Device, typename DType = default_real_t>
struct BLASEngine {
inline static bool GetT(bool t) {
return t ? true : false;
}
inline static void SetStream(Stream<Device> *stream) {
}
inline static void gemm(Stream<Device> *stream,
bool transa, bool transb,
int m, int n, int k, DType alpha,
const DType *A, int lda, const DType *B, int ldb,
DType beta, DType *C, int ldc) {
LOG(FATAL) << "Not implmented!";
}
inline static void batched_gemm(Stream<Device> *stream,
bool transa, bool transb,
int m, int n, int k, DType alpha,
const DType *A, int lda, const DType *B, int ldb,
DType beta, DType *C, int ldc, int batch_count,
DType **workspace) {
LOG(FATAL) << "Not implmented!";
}
inline static void gemv(Stream<Device> *stream,
bool trans, int m, int n,
DType alpha, const DType *A, int lda,
const DType *X, int incX,
DType beta, DType *Y, int incY) {
LOG(FATAL) << "Not implmented!";
}
inline static void batched_gemv(Stream<Device> *stream,
bool trans, int m, int n,
DType alpha, const DType *A, int lda,
const DType *X, int incX,
DType beta, DType *Y, int incY, int batch_count) {
LOG(FATAL) << "Not implmented!";
}
inline static void ger(Stream<Device> *stream,
int m, int n, DType alpha,
const DType *X, int incX,
const DType *Y, int incY, DType *A, int lda) {
LOG(FATAL) << "Not implmented!";
}
inline static void batched_ger(Stream<Device> *stream,
int m, int n, DType alpha,
const DType *X, int incX,
const DType *Y, int incY, DType *A, int lda, int batch_count) {
LOG(FATAL) << "Not implmented!";
}
inline static void dot(Stream<Device> *stream,
int n,
const DType* X, int incX,
const DType* Y, int incY,
DType* ret) {
LOG(FATAL) << "Not implmented!";
}
};
#if MSHADOW_STAND_ALONE
template<>
struct BLASEngine<cpu, float> {
inline static bool GetT(bool t) {
return t ? true : false;
}
inline static void SetStream(Stream<cpu> *stream) {
}
inline static void gemm(Stream<cpu> *stream,
bool transa, bool transb,
int m, int n, int k, float alpha,
const float *A, int lda, const float *B, int ldb,
float beta, float *C, int ldc) {
if (alpha == 1.0f && beta == 0.0f) {
bool transpose_left = transb;
bool transpose_right = transa;
Tensor<cpu, 2, float> lhs((float*)B, Shape2(transpose_left ? k : n, transpose_left ? n : k)); // NOLINT(*)
Tensor<cpu, 2, float> rhs((float*)A, Shape2(transpose_right ? m : k, transpose_right ? k : m)); // NOLINT(*)
Tensor<cpu, 2, float> dst(C, Shape2(m, n));
if (!transpose_left && !transpose_right) {
dst = expr::implicit_dot(lhs, rhs); return;
} else if (!transpose_left && transpose_right) {
dst = expr::implicit_dot(lhs, rhs.T()); return;
} else if (transpose_left && !transpose_right) {
dst = expr::implicit_dot(lhs.T(), rhs); return;
} else {
LOG(FATAL) << "Not implmented!";
}
} else {
LOG(FATAL) << "Not implmented!";
}
}
inline static void batched_gemm(Stream<cpu> *stream,
bool transa, bool transb,
int m, int n, int k, float alpha,
const float *A, int lda, const float *B, int ldb,
float beta, float *C, int ldc, int batch_count,
float **workspace) {
for (int i = 0; i < batch_count; ++i) {
gemm(stream, transa, transb, m, n, k, alpha,
A + i * m * k, lda, B + i * k * n, ldb,
beta, C + i * m * n, ldc);
}
}
inline static void gemv(Stream<cpu> *stream,
bool trans, int m, int n,
float alpha, const float *A, int lda,
const float *X, int incX,
float beta, float *Y, int incY) {
LOG(FATAL) << "Not implmented!";
}
inline static void batched_gemv(Stream<cpu> *stream,
bool trans, int m, int n,
float alpha, const float *A, int lda,
const float *X, int incX,
float beta, float *Y, int incY, int batch_count) {
LOG(FATAL) << "Not implmented!";
}
inline static void ger(Stream<cpu> *stream,
int m, int n, float alpha,
const float *X, int incX,
const float *Y, int incY, float *A, int lda) {
LOG(FATAL) << "Not implmented!";
}
inline static void batched_ger(Stream<cpu> *stream,
int m, int n, float alpha,
const float *X, int incX,
const float *Y, int incY, float *A, int lda, int batch_count) {
LOG(FATAL) << "Not implmented!";
}
inline static void dot(Stream<cpu> *stream,
int n,
const float* X, int incX,
const float* Y, int incY,
float* ret) {
LOG(FATAL) << "Not implmented!";
}
};
template<>
struct BLASEngine<cpu, double> {
inline static bool GetT(bool t) {
return t ? true : false;
}
inline static void SetStream(Stream<cpu> *stream) {
}
inline static void gemm(Stream<cpu> *stream,
bool transa, bool transb,
int m, int n, int k, double alpha,
const double *A, int lda, const double *B, int ldb,
double beta, double *C, int ldc) {
if (alpha == 1.0f && beta == 0.0f) {
bool transpose_left = transb;
bool transpose_right = transa;
Tensor<cpu, 2, double> lhs((double*)B, Shape2(transpose_left ? k : n, transpose_left ? n : k)); // NOLINT(*)
Tensor<cpu, 2, double> rhs((double*)A, Shape2(transpose_right ? m : k, transpose_right ? k : m)); // NOLINT(*)
Tensor<cpu, 2, double> dst(C, Shape2(m, n));
if (!transpose_left && !transpose_right) {
dst = expr::implicit_dot(lhs, rhs); return;
} else if (!transpose_left && transpose_right) {
dst = expr::implicit_dot(lhs, rhs.T()); return;
} else if (transpose_left && !transpose_right) {
dst = expr::implicit_dot(lhs.T(), rhs); return;
} else {
        LOG(FATAL) << "Not implemented!";
}
} else {
      LOG(FATAL) << "Not implemented!";
}
}
inline static void batched_gemm(Stream<cpu> *stream,
bool transa, bool transb,
int m, int n, int k, double alpha,
const double *A, int lda, const double *B, int ldb,
double beta, double *C, int ldc, int batch_count,
double **workspace) {
for (int i = 0; i < batch_count; ++i) {
gemm(stream, transa, transb, m, n, k, alpha,
A + i * m * k, lda, B + i * k * n, ldb,
beta, C + i * m * n, ldc);
}
}
inline static void gemv(Stream<cpu> *stream,
bool trans, int m, int n,
double alpha, const double *A, int lda,
const double *X, int incX,
double beta, double *Y, int incY) {
    LOG(FATAL) << "Not implemented!";
}
inline static void batched_gemv(Stream<cpu> *stream,
bool trans, int m, int n,
double alpha, const double *A, int lda,
const double *X, int incX,
double beta, double *Y, int incY, int batch_count) {
    LOG(FATAL) << "Not implemented!";
}
inline static void ger(Stream<cpu> *stream,
int m, int n, double alpha,
const double *X, int incX,
const double *Y, int incY, double *A, int lda) {
    LOG(FATAL) << "Not implemented!";
}
inline static void batched_ger(Stream<cpu> *stream,
int m, int n, double alpha,
const double *X, int incX,
const double *Y, int incY, double *A, int lda, int batch_count) {
    LOG(FATAL) << "Not implemented!";
}
inline static void dot(Stream<cpu> *stream,
int n,
const double* X, int incX,
const double* Y, int incY,
double* ret) {
    LOG(FATAL) << "Not implemented!";
}
};
#elif (MSHADOW_USE_MKL || MSHADOW_USE_CBLAS) // NOLINT(*)
template<>
struct BLASEngine<cpu, float> {
inline static CBLAS_TRANSPOSE GetT(bool t) {
return t ? CblasTrans : CblasNoTrans;
}
inline static void SetStream(Stream<cpu> *stream) {
}
inline static void gemm(Stream<cpu> *stream,
bool transa, bool transb,
int m, int n, int k, float alpha,
const float *A, int lda, const float *B, int ldb,
float beta, float *C, int ldc) {
cblas_sgemm(CblasColMajor, GetT(transa), GetT(transb),
m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
}
inline static void batched_gemm(Stream<cpu> *stream,
bool transa, bool transb,
int m, int n, int k, float alpha,
const float *A, int lda, const float *B, int ldb,
float beta, float *C, int ldc, int batch_count,
float **workspace) {
for (int i = 0; i < batch_count; ++i) {
gemm(stream, transa, transb, m, n, k, alpha,
A + i * m * k, lda, B + i * k * n, ldb,
beta, C + i * m * n, ldc);
}
}
inline static void gemv(Stream<cpu> *stream,
bool trans, int m, int n,
float alpha, const float *A, int lda,
const float *X, int incX,
float beta, float *Y, int incY) {
cblas_sgemv(CblasColMajor, GetT(trans), m, n, alpha,
A, lda, X, incX, beta, Y, incY);
}
inline static void batched_gemv(Stream<cpu> *stream,
bool trans, int m, int n,
float alpha, const float *A, int lda,
const float *X, int incX,
float beta, float *Y, int incY, int batch_count) {
for (int i = 0; i < batch_count; ++i) {
gemv(stream, trans, m, n, alpha, A + i * m * n, lda,
X + i * (trans ? m : n) * incX, incX,
beta, Y + i * (trans ? n : m) * incY, incY);
}
}
inline static void ger(Stream<cpu> *stream,
int m, int n, float alpha,
const float *X, int incX,
const float *Y, int incY, float *A, int lda) {
cblas_sger(CblasColMajor, m, n, alpha, X, incX, Y, incY, A, lda);
}
inline static void batched_ger(Stream<cpu> *stream,
int m, int n, float alpha,
const float *X, int incX,
const float *Y, int incY, float *A, int lda, int batch_count) {
for (int i = 0; i < batch_count; ++i) {
ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY,
A + i * lda * n, lda);
}
}
inline static void dot(Stream<cpu> *stream,
int n,
const float* X, int incX,
const float* Y, int incY,
float* ret) {
*ret = cblas_sdot(n, X, incX, Y, incY);
}
};
template<>
struct BLASEngine<cpu, double> {
inline static CBLAS_TRANSPOSE GetT(bool t) {
return t ? CblasTrans : CblasNoTrans;
}
inline static void SetStream(Stream<cpu> *stream) {
}
inline static void gemm(Stream<cpu> *stream,
bool transa, bool transb,
int m, int n, int k, double alpha,
const double *A, int lda, const double *B, int ldb,
double beta, double *C, int ldc) {
cblas_dgemm(CblasColMajor, GetT(transa), GetT(transb),
m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
}
inline static void batched_gemm(Stream<cpu> *stream,
bool transa, bool transb,
int m, int n, int k, double alpha,
const double *A, int lda, const double *B, int ldb,
double beta, double *C, int ldc, int batch_count,
double **workspace) {
for (int i = 0; i < batch_count; ++i) {
gemm(stream, transa, transb, m, n, k, alpha,
A + i * m * k, lda, B + i * k * n, ldb,
beta, C + i * m * n, ldc);
}
}
inline static void gemv(Stream<cpu> *stream,
bool trans, int m, int n, double alpha,
const double *A, int lda,
const double *X, int incX,
double beta, double *Y, int incY) {
cblas_dgemv(CblasColMajor, GetT(trans), m, n, alpha,
A, lda, X, incX, beta, Y, incY);
}
inline static void batched_gemv(Stream<cpu> *stream,
bool trans, int m, int n,
double alpha, const double *A, int lda,
const double *X, int incX,
double beta, double *Y, int incY, int batch_count) {
for (int i = 0; i < batch_count; ++i) {
gemv(stream, trans, m, n, alpha, A + i * m * n, lda,
X + i * (trans ? m : n) * incX, incX,
beta, Y + i * (trans ? n : m) * incY, incY);
}
}
inline static void ger(Stream<cpu> *stream,
int m, int n, double alpha,
const double *X, int incX,
const double *Y, int incY, double *A, int lda) {
cblas_dger(CblasColMajor, m, n, alpha, X, incX, Y, incY, A, lda);
}
inline static void batched_ger(Stream<cpu> *stream,
int m, int n, double alpha,
const double *X, int incX,
const double *Y, int incY, double *A, int lda, int batch_count) {
for (int i = 0; i < batch_count; ++i) {
ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY,
A + i * lda * n, lda);
}
}
inline static void dot(Stream<cpu> *stream,
int n,
const double* X, int incX,
const double* Y, int incY,
double* ret) {
*ret = cblas_ddot(n, X, incX, Y, incY);
}
};
#endif // MSHADOW_USE_CBLAS || MSHADOW_USE_MKL || MSHADOW_STAND_ALONE
// CuBLAS redirect code
#if MSHADOW_USE_CUDA
// All cuBLAS calls are routed through here; uses the legacy API, which is not thread-safe
template<>
struct BLASEngine<gpu, half::half_t> {
inline static cublasOperation_t GetT(bool t) {
return t ? CUBLAS_OP_T : CUBLAS_OP_N;
}
inline static void SetStream(Stream<gpu> *stream) {
cublasStatus_t err = cublasSetStream(Stream<gpu>::GetBlasHandle(stream),
Stream<gpu>::GetStream(stream));
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas set stream fail";
}
inline static void gemm(Stream<gpu> *stream,
bool transa, bool transb,
int m, int n, int k, half::half_t alpha,
const half::half_t *A, int lda,
const half::half_t *B, int ldb, half::half_t beta,
half::half_t *C, int ldc) {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 7050
#if MSHADOW_USE_PASCAL == 1
cublasStatus_t err = cublasHgemm(Stream<gpu>::GetBlasHandle(stream),
GetT(transa), GetT(transb), m, n, k, &alpha.cuhalf_,
&A->cuhalf_, lda, &B->cuhalf_, ldb, &beta.cuhalf_, &C->cuhalf_, ldc);
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas Hgemm fail";
#else
float alpha_f = float(alpha); // NOLINT(*)
float beta_f = float(beta); // NOLINT(*)
#if CUDA_VERSION >= 8000
cublasStatus_t err = cublasSgemmEx(Stream<gpu>::GetBlasHandle(stream),
GetT(transa), GetT(transb), m, n, k, &alpha_f,
A, CUDA_R_16F, lda, B, CUDA_R_16F,
ldb, &beta_f, C, CUDA_R_16F, ldc);
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas SgemmEx fail";
#else
cublasStatus_t err = cublasSgemmEx(Stream<gpu>::GetBlasHandle(stream),
GetT(transa), GetT(transb), m, n, k, &alpha_f,
A, CUBLAS_DATA_HALF, lda, B, CUBLAS_DATA_HALF,
ldb, &beta_f, C, CUBLAS_DATA_HALF, ldc);
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas SgemmEx fail";
#endif // CUDA_VERSION >= 8000
#endif // MSHADOW_USE_PASCAL == 1
#else
LOG(FATAL) << "Require CUDA version >= 7.5!";
#endif // defined(CUDA_VERSION) && CUDA_VERSION >= 7050
}
inline static void batched_gemm(Stream<gpu> *stream,
bool transa, bool transb,
int m, int n, int k, half::half_t alpha,
const half::half_t *A, int lda, const half::half_t *B, int ldb,
half::half_t beta, half::half_t *C, int ldc, int batch_count,
half::half_t **workspace) {
for (int i = 0; i < batch_count; ++i) {
gemm(stream, transa, transb, m, n, k, alpha,
A + i * m * k, lda, B + i * k * n, ldb,
beta, C + i * m * n, ldc);
}
}
inline static void gemv(Stream<gpu> *stream,
bool trans, int m, int n, half::half_t alpha,
const half::half_t *A, int lda,
const half::half_t *X, int incX, half::half_t beta,
half::half_t *Y, int incY) {
    LOG(FATAL) << "Not implemented!";
}
inline static void batched_gemv(Stream<gpu> *stream,
bool trans, int m, int n,
half::half_t alpha, const half::half_t *A, int lda,
const half::half_t *X, int incX,
half::half_t beta, half::half_t *Y, int incY, int batch_count) {
    LOG(FATAL) << "Not implemented!";
}
inline static void ger(Stream<gpu> *stream,
int m, int n, half::half_t alpha,
const half::half_t *X, int incX,
const half::half_t *Y, int incY, half::half_t *A, int lda) {
    LOG(FATAL) << "Not implemented!";
}
inline static void batched_ger(Stream<gpu> *stream,
int m, int n, half::half_t alpha,
const half::half_t *X, int incX, const half::half_t *Y, int incY,
half::half_t *A, int lda, int batch_count) {
    LOG(FATAL) << "Not implemented!";
}
inline static void dot(Stream<gpu> *stream,
int n,
const half::half_t* X, int incX,
const half::half_t* Y, int incY,
half::half_t *ret) {
    LOG(FATAL) << "Not implemented!";
}
};
template<>
struct BLASEngine<gpu, float> {
inline static cublasOperation_t GetT(bool t) {
return t ? CUBLAS_OP_T : CUBLAS_OP_N;
}
inline static void SetStream(Stream<gpu> *stream) {
cublasStatus_t err = cublasSetStream(Stream<gpu>::GetBlasHandle(stream),
Stream<gpu>::GetStream(stream));
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: set stream fail";
}
inline static void gemm(Stream<gpu> *stream,
bool transa, bool transb,
int m, int n, int k, float alpha,
const float *A, int lda,
const float *B, int ldb, float beta,
float *C, int ldc) {
cublasStatus_t err = cublasSgemm(Stream<gpu>::GetBlasHandle(stream),
GetT(transa), GetT(transb), m, n, k, &alpha,
A, lda, B, ldb, &beta, C, ldc);
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Sgemm fail";
}
inline static void batched_gemm(Stream<gpu> *stream,
bool transa, bool transb,
int m, int n, int k, float alpha,
const float *A, int lda, const float *B, int ldb,
float beta, float *C, int ldc, int batch_count,
float **workspace) {
#if defined(__CUDACC__) && CUDA_VERSION >= 4010
// Cast DType* to DType** using workspace as a buffer
bool alloc_workspace = false;
if (workspace == NULL) {
// Allocate the workspace if it's NULL.
// TODO(sxjscience) Try to move the allocation inside Tensor, which is thread-safe.
cudaMalloc(reinterpret_cast<void**>(&workspace), 3 * batch_count * sizeof(float*));
alloc_workspace = true;
}
GetBatchedView(workspace, const_cast<float*>(A), batch_count, m * k, stream);
GetBatchedView(workspace + batch_count,
const_cast<float*>(B), batch_count, k * n, stream);
GetBatchedView(workspace + 2 * batch_count, C, batch_count, m * n, stream);
cublasStatus_t err = cublasSgemmBatched(Stream<gpu>::GetBlasHandle(stream),
GetT(transa), GetT(transb), m, n, k, &alpha,
(const float**)workspace, lda,
(const float**)(workspace + batch_count), ldb,
&beta, workspace + 2 * batch_count, ldc, batch_count);
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: SgemmBatched fail";
if (alloc_workspace) {
cudaFree(workspace);
}
#else
for (int i = 0; i < batch_count; ++i) {
gemm(stream, transa, transb, m, n, k, alpha,
A + i * m * k, lda, B + i * k * n, ldb,
beta, C + i * m * n, ldc);
}
#endif // defined(__CUDACC__) && CUDA_VERSION >= 4010
}
inline static void gemv(Stream<gpu> *stream,
bool trans, int m, int n, float alpha,
const float *A, int lda,
const float *X, int incX, float beta,
float *Y, int incY) {
cublasStatus_t err = cublasSgemv(Stream<gpu>::GetBlasHandle(stream),
GetT(trans), m, n, &alpha, A, lda, X, incX, &beta, Y, incY);
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Sgemv fail";
}
inline static void batched_gemv(Stream<gpu> *stream,
bool trans, int m, int n,
float alpha, const float *A, int lda,
const float *X, int incX,
float beta, float *Y, int incY, int batch_count) {
for (int i = 0; i < batch_count; ++i) {
gemv(stream, trans, m, n, alpha, A + i * m * n, lda,
X + i * (trans ? m : n) * incX, incX,
beta, Y + i * (trans ? n : m) * incY, incY);
}
}
inline static void ger(Stream<gpu> *stream,
int m, int n, float alpha,
const float *X, int incX,
const float *Y, int incY, float *A, int lda) {
cublasStatus_t err = cublasSger(Stream<gpu>::GetBlasHandle(stream),
m, n, &alpha, X, incX, Y, incY, A, lda);
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Sger fail";
}
inline static void batched_ger(Stream<gpu> *stream,
int m, int n, float alpha,
const float *X, int incX,
const float *Y, int incY, float *A, int lda, int batch_count) {
for (int i = 0; i < batch_count; ++i) {
ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY,
A + i * lda * n, lda);
}
}
inline static void dot(Stream<gpu> *stream,
int n,
const float* X, int incX,
const float* Y, int incY,
float *ret) {
cublasSetPointerMode(Stream<gpu>::GetBlasHandle(stream),
CUBLAS_POINTER_MODE_DEVICE);
cublasStatus_t err = cublasSdot(Stream<gpu>::GetBlasHandle(stream),
n, X, incX, Y, incY, ret);
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dot fail";
cublasSetPointerMode(Stream<gpu>::GetBlasHandle(stream),
CUBLAS_POINTER_MODE_HOST);
}
};
template<>
struct BLASEngine<gpu, double> {
inline static cublasOperation_t GetT(bool t) {
return t ? CUBLAS_OP_T : CUBLAS_OP_N;
}
inline static void SetStream(Stream<gpu> *stream) {
cublasStatus_t err = cublasSetStream(Stream<gpu>::GetBlasHandle(stream),
Stream<gpu>::GetStream(stream));
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: set stream fail";
}
inline static void gemm(Stream<gpu> *stream,
bool transa, bool transb,
int m, int n, int k, double alpha,
const double *A, int lda,
const double *B, int ldb,
double beta, double *C, int ldc) {
cublasStatus_t err = cublasDgemm(Stream<gpu>::GetBlasHandle(stream),
GetT(transa), GetT(transb), m, n, k, &alpha,
A, lda, B, ldb, &beta, C, ldc);
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dgemm fail";
}
inline static void batched_gemm(Stream<gpu> *stream,
bool transa, bool transb,
int m, int n, int k, double alpha,
const double *A, int lda, const double *B, int ldb,
double beta, double *C, int ldc, int batch_count,
double **workspace) {
#if defined(__CUDACC__) && CUDA_VERSION >= 4010
// Cast DType* to DType** using workspace as a buffer
bool alloc_workspace = false;
if (workspace == NULL) {
// Allocate the workspace if it's NULL.
// TODO(sxjscience) Try to move the allocation inside Tensor, which is thread-safe.
cudaMalloc(reinterpret_cast<void**>(&workspace), 3 * batch_count * sizeof(double*));
alloc_workspace = true;
}
GetBatchedView(workspace, const_cast<double*>(A), batch_count, m * k, stream);
GetBatchedView(workspace + batch_count,
const_cast<double*>(B), batch_count, k * n, stream);
GetBatchedView(workspace + 2 * batch_count, C, batch_count, m * n, stream);
cublasStatus_t err = cublasDgemmBatched(Stream<gpu>::GetBlasHandle(stream),
GetT(transa), GetT(transb), m, n, k, &alpha,
(const double**)workspace, lda,
(const double**)(workspace + batch_count), ldb,
&beta, workspace + 2 * batch_count, ldc, batch_count);
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: DgemmBatched fail";
if (alloc_workspace) {
cudaFree(workspace);
}
#else
for (int i = 0; i < batch_count; ++i) {
gemm(stream, transa, transb, m, n, k, alpha,
A + i * m * k, lda, B + i * k * n, ldb,
beta, C + i * m * n, ldc);
}
#endif // defined(__CUDACC__) && CUDA_VERSION >= 4010
}
inline static void gemv(Stream<gpu> *stream,
bool trans, int m, int n, double alpha,
const double *A, int lda,
const double *X, int incX,
double beta, double *Y, int incY) {
cublasStatus_t err = cublasDgemv(Stream<gpu>::GetBlasHandle(stream),
GetT(trans), m, n, &alpha, A, lda, X, incX, &beta, Y, incY);
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dgemv fail";
}
inline static void batched_gemv(Stream<gpu> *stream,
bool trans, int m, int n,
double alpha, const double *A, int lda,
const double *X, int incX,
double beta, double *Y, int incY, int batch_count) {
for (int i = 0; i < batch_count; ++i) {
gemv(stream, trans, m, n, alpha, A + i * m * n, lda,
X + i * (trans ? m : n) * incX, incX,
beta, Y + i * (trans ? n : m) * incY, incY);
}
}
inline static void ger(Stream<gpu> *stream,
int m, int n, double alpha,
const double *X, int incX,
const double *Y, int incY, double *A, int lda) {
cublasStatus_t err = cublasDger(Stream<gpu>::GetBlasHandle(stream),
m, n, &alpha, X, incX, Y, incY, A, lda);
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dger fail";
}
inline static void batched_ger(Stream<gpu> *stream,
int m, int n, double alpha,
const double *X, int incX,
const double *Y, int incY, double *A, int lda, int batch_count) {
for (int i = 0; i < batch_count; ++i) {
ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY,
A + i * lda * n, lda);
}
}
inline static void dot(Stream<gpu> *stream,
int n,
const double* X, int incX,
const double* Y, int incY,
double *ret) {
cublasSetPointerMode(Stream<gpu>::GetBlasHandle(stream),
CUBLAS_POINTER_MODE_DEVICE);
cublasStatus_t err = cublasDdot(Stream<gpu>::GetBlasHandle(stream),
n, X, incX, Y, incY, ret);
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dot fail";
cublasSetPointerMode(Stream<gpu>::GetBlasHandle(stream),
CUBLAS_POINTER_MODE_HOST);
}
};
#endif // MSHADOW_USE_CUDA
// helper function to get the effective shape, applying the optional transpose
inline Shape<2> GetShape(const Shape<2> &shape, bool transpose) {
return transpose ? Shape2(shape[1], shape[0]) : shape;
}
// dst = dot(lhs[.T], rhs[.T])
template<typename SV, typename xpu,
bool transpose_left, bool transpose_right, typename DType>
struct DotEngine<SV, xpu, 2, 2, 2, transpose_left, transpose_right, DType> {
inline static void Eval(Tensor<xpu, 2, DType> *p_dst,
const Tensor<xpu, 2, DType> &lhs,
const Tensor<xpu, 2, DType> &rhs,
DType scale) {
Tensor<xpu, 2, DType> &dst = *p_dst;
#if MSHADOW_STAND_ALONE
if (xpu::kDevMask == cpu::kDevMask && scale == 1.0f) {
if (!transpose_left && !transpose_right) {
dst = expr::implicit_dot(lhs, rhs); return;
} else if (!transpose_left && transpose_right) {
dst = expr::implicit_dot(lhs, rhs.T()); return;
} else if (transpose_left && !transpose_right) {
dst = expr::implicit_dot(lhs.T(), rhs); return;
}
}
#endif
// set kernel stream
    // if there is no stream, this will crash
BLASEngine<xpu, DType>::SetStream(dst.stream_);
Shape<2> sleft = GetShape(lhs.shape_, transpose_left);
Shape<2> sright = GetShape(rhs.shape_, transpose_right);
CHECK(dst.size(0) == sleft[0] && dst.size(1) == sright[1] && sleft[1] == sright[0])
<< "dot-gemm: matrix shape mismatch";
    // use column-major arguments to be compatible with most BLAS implementations
BLASEngine<xpu, DType>::gemm
(dst.stream_,
transpose_right , transpose_left,
transpose_right ? rhs.size(0) : rhs.size(1),
transpose_left ? lhs.size(1) : lhs.size(0),
transpose_right ? rhs.size(1) : rhs.size(0),
DType(scale * SV::AlphaBLAS()),
rhs.dptr_, rhs.stride_,
lhs.dptr_, lhs.stride_,
DType(SV::BetaBLAS()),
dst.dptr_, dst.stride_);
}
};
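// A short worked note on the column-major trick used above (a sketch, assuming
// the CBLAS/cuBLAS specializations of BLASEngine): for row-major tensors with
// dst(m x n) = dot(lhs(m x k), rhs(k x n)) and no transposes, the call made is
//
//   gemm(stream, false, false,
//        n /* rows of dst^T */, m /* cols of dst^T */, k,
//        alpha, rhs.dptr_, rhs.stride_, lhs.dptr_, lhs.stride_,
//        beta, dst.dptr_, dst.stride_);
//
// i.e. it computes dst^T = rhs^T * lhs^T in column-major storage, which has
// exactly the same memory layout as the desired row-major dst = lhs * rhs.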
template<typename SV, typename xpu, bool transpose_right, typename DType>
struct DotEngine<SV, xpu, 1, 1, 2, false, transpose_right, DType> {
inline static void Eval(Tensor<xpu, 1, DType> *p_dst,
const Tensor<xpu, 1, DType> &lhs,
const Tensor<xpu, 2, DType> &rhs,
DType scale) {
Tensor<xpu, 1, DType> &dst = *p_dst;
// set kernel stream
    // if there is no stream, this will crash
BLASEngine<xpu, DType>::SetStream(dst.stream_);
Shape<2> sright = GetShape(rhs.shape_, transpose_right);
CHECK(dst.size(0) == sright[1] && lhs.size(0) == sright[0])
      << "dot-gemv: matrix shape mismatch\n"
<< "dst: " << dst.shape_ << "\n"
<< "lhs: " << lhs.shape_ << "\n"
<< "rhs: " << sright << "\n";
BLASEngine<xpu, DType>::gemv
(dst.stream_,
transpose_right,
rhs.size(1), rhs.size(0), scale * SV::AlphaBLAS(),
rhs.dptr_, rhs.stride_,
lhs.dptr_, 1, SV::BetaBLAS(),
dst.dptr_, 1);
}
};
template<typename SV, typename xpu, typename DType>
struct DotEngine<SV, xpu, 2, 1, 1, true, false, DType> {
inline static void Eval(Tensor<xpu, 2, DType> *p_dst,
const Tensor<xpu, 1, DType> &lhs,
const Tensor<xpu, 1, DType> &rhs,
DType scale) {
Tensor<xpu, 2, DType> &dst = *p_dst;
// set kernel stream
    // if there is no stream, this will crash
BLASEngine<xpu, DType>::SetStream(dst.stream_);
CHECK(dst.size(0) == lhs.size(0) && dst.size(1) == rhs.size(0))
        << "dot-ger: matrix shape mismatch\n"
<< "dst: " << dst.shape_ << "\n"
<< "lhs: " << lhs.shape_ << "\n"
<< "rhs: " << rhs.shape_;
if (SV::BetaBLAS() == 0.0f) {
BLASEngine<xpu, DType>::ger
(dst.stream_, rhs.size(0), lhs.size(0), scale * SV::AlphaBLAS(),
rhs.dptr_, 1, lhs.dptr_, 1, dst.dptr_, dst.stride_);
} else {
DotEngine<SV, xpu, 2, 2, 2, true, false,
DType>::Eval(p_dst, lhs.FlatTo2D(), rhs.FlatTo2D(), scale);
}
}
};
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_DOT_ENGINE_INL_H_
//===== EXPANDED: ../mshadow/mshadow/dot_engine-inl.h =====
namespace mshadow {
namespace expr {
/*! \brief an engine that evaluates complex expressions */
template<typename SV, typename RV, typename E, typename DType>
struct ExpComplexEngine {
inline static void Eval(RV *dst, const E &exp);
};
/*! \brief the engine that dispatches simple operations*/
template<typename SV, typename RV, typename DType>
struct ExpEngine {
template<typename E>
inline static void Eval(RV *dst,
const Exp<E, DType, type::kMapper> &exp) {
MapExp<SV>(dst, exp);
}
template<typename E>
inline static void Eval(RV *dst,
const Exp<E, DType, type::kChainer> &exp) {
MapExp<SV>(dst, exp);
}
template<typename E>
inline static void Eval(RV *dst,
const Exp<E, DType, type::kRValue> &exp) {
MapExp<SV>(dst, exp);
}
template<typename E>
inline static void Eval(RV *dst,
const Exp<E, DType, type::kComplex> &exp) {
ExpComplexEngine<SV, RV, E, DType>::Eval(dst->ptrself(), exp.self());
}
};
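// Dispatch sketch (illustrative; A, B, C are assumed to be allocated tensors):
// element-wise expressions are tagged kMapper/kChainer/kRValue and go straight
// to MapExp, while dot() yields a kComplex expression that is routed through
// ExpComplexEngine to DotEngine/BLASEngine, e.g.
//
//   Tensor<cpu, 2, float> A, B, C;   // assume all three are 2x2 and allocated
//   C = A + B;                       // kMapper  -> MapExp
//   C = dot(A, B);                   // kComplex -> ExpComplexEngine -> DotEngine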
template<typename SV, typename Device, int dim, int ldim,
int rdim, bool ltrans, bool rtrans, typename DType>
struct ExpComplexEngine<SV,
Tensor<Device, dim, DType>,
DotExp<Tensor<Device, ldim, DType>,
Tensor<Device, rdim, DType>,
ltrans, rtrans, DType>,
DType> {
inline static void Eval(Tensor<Device, dim, DType> *dst,
const DotExp<Tensor<Device, ldim, DType>,
Tensor<Device, rdim, DType>,
ltrans, rtrans, DType> &exp) {
DotEngine<SV, Device, dim, ldim, rdim,
ltrans, rtrans, DType>::Eval(dst, exp.lhs_, exp.rhs_, exp.scale_);
}
};
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXPR_ENGINE_INL_H_
//===== EXPANDED: ../mshadow/mshadow/expr_engine-inl.h =====
//===== EXPANDING: ../mshadow/mshadow/extension/broadcast.h =====
/*!
* Copyright (c) 2014 by Contributors
* \file broadcast.h
* \brief support for broadcast and repmat
* \author Tianqi Chen
*/
#ifndef MSHADOW_EXTENSION_BROADCAST_H_
#define MSHADOW_EXTENSION_BROADCAST_H_
namespace mshadow {
namespace expr {
/*!
* \brief broadcast Tensor1D into a higher dimension Tensor
* input: Tensor<Device,1>: ishape[0]
* output: Tensor<Device,dimdst> : oshape[dimcast] = ishape[0]
* \tparam SrcExp type of input expression
* \tparam DType the type of elements
* \tparam dimdst target tensor dimension
 * \tparam dimdst_m_cast dimdst - dimcast
*/
template<typename SrcExp, typename DType, int dimdst, int dimdst_m_cast>
struct Broadcast1DExp:
public MakeTensorExp<Broadcast1DExp<SrcExp, DType, dimdst, dimdst_m_cast>,
SrcExp, dimdst, DType> {
/*! \brief source operand */
const SrcExp &src_;
/*! \brief constructor */
Broadcast1DExp(const SrcExp &src, Shape<dimdst> shape)
: src_(src) {
this->shape_ = shape;
}
};
/*!
* \brief broadcast scalar into a higher dimension Tensor
* input: Tensor<Device,1>: ishape = {1}
 * output: Tensor<Device, dimdst> : all elements are broadcast from the single input element
* \tparam SrcExp type of input expression
* \tparam DType the type of elements
* \tparam dimdst target tensor dimension
*/
template<typename SrcExp, typename DType, int dimdst>
struct BroadcastScalarExp:
public MakeTensorExp<BroadcastScalarExp<SrcExp, DType, dimdst>,
SrcExp, dimdst, DType> {
/*! \brief source operand */
const SrcExp &src_;
/*! \brief constructor */
BroadcastScalarExp(const SrcExp &src, Shape<dimdst> shape)
: src_(src) {
this->shape_ = shape;
}
};
/*!
 * \brief an expression that replicates a 1-dimensional tensor along dimension dimcast
 * \param src Tensor<Device,1>: shape[0]
 * \param shape shape of output
 * \return an expression with type Tensor<Device,dimdst>
* \tparam dimcast target dimension where the 1D tensor will be broadcasted
* \tparam SrcExp type of input expression
* \tparam DType the type of elements
* \tparam dimdst dimension of destination tensor
* \tparam dimcast_lowest the dimension we want to cast the data into
*/
template<int dimcast, typename SrcExp, typename DType,
int etype, int dimdst>
inline Broadcast1DExp<SrcExp, DType, dimdst, dimdst - dimcast>
broadcast(const expr::Exp<SrcExp, DType, etype> &src, Shape<dimdst> shape) {
TypeCheckPass<dimcast < dimdst && ExpInfo<SrcExp>::kDim == 1>
::Error_Expression_Does_Not_Meet_Dimension_Req();
typedef ShapeCheck<1, SrcExp> ShapeCheckDim1SrcExp;
CHECK_EQ(ShapeCheckDim1SrcExp::Check(src.self())[0], shape[dimcast])
<< "broadcast, shape mismatch";
return Broadcast1DExp<SrcExp, DType, dimdst,
dimdst - dimcast>(src.self(), shape);
}
/*!
 * \brief an expression that replicates a scalar tensor to the target dimension.
 * \param src Tensor<Device,1>: shape[0] == 1
 * \param shape shape of output
 * \return an expression with type Tensor<Device, dimdst>
* \tparam dimcast target dimension where the 1D tensor will be broadcasted
* \tparam SrcExp type of input expression
* \tparam DType the type of elements
* \tparam dimdst dimension of destination tensor
*/
template<typename SrcExp, typename DType, int etype, int dimdst>
inline BroadcastScalarExp<SrcExp, DType, dimdst>
broadcast_scalar(const expr::Exp<SrcExp, DType, etype> &src, Shape<dimdst> shape) {
TypeCheckPass<ExpInfo<SrcExp>::kDim == 1>
::Error_Expression_Does_Not_Meet_Dimension_Req();
typedef ShapeCheck<1, SrcExp> ShapeCheckDim1SrcExp;
CHECK_EQ(ShapeCheckDim1SrcExp::Check(src.self())[0], 1)
    << "broadcast_scalar, source needs to be a scalar expression";
return BroadcastScalarExp<SrcExp, DType, dimdst>(src.self(), shape);
}
// short cut functions
/*!
 * \brief an expression that replicates a 1-dimensional tensor nrow times
 * \param src Tensor<Device,1>: shape[0]
 * \param nrow number of rows to replicate
 * \return an expression with type Tensor<Device,2>: size(1) = shape[0], size(0) = nrow
 * \tparam Device the device the tensor lies on
*/
template<typename SrcExp, typename DType, int etype>
inline Broadcast1DExp<SrcExp, DType, 2, 1>
repmat(const expr::Exp<SrcExp, DType, etype> &src, index_t nrow) {
return broadcast<1>
(src, Shape2(nrow, ShapeCheck<1, SrcExp>::Check(src.self())[0]));
}
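// Usage sketch (illustrative only, assuming mshadow's NewTensor/FreeSpace
// helpers): broadcast a length-4 bias vector over the rows of a 3x4 matrix.
//
//   Tensor<cpu, 1, float> bias = NewTensor<cpu>(Shape1(4), 1.0f);
//   Tensor<cpu, 2, float> out  = NewTensor<cpu>(Shape2(3, 4), 0.0f);
//   out  = repmat(bias, 3);                  // same as broadcast<1>(bias, out.shape_)
//   out += broadcast<1>(bias, out.shape_);   // dimension 1 of out matches bias.size(0)
//   FreeSpace(&bias); FreeSpace(&out);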
//----------------------
// Execution plan
//----------------------
template<typename SrcExp, typename DType, int dimdst, int dimdst_m_cast>
struct Plan<Broadcast1DExp<SrcExp, DType, dimdst, dimdst_m_cast>, DType> {
public:
static const int dimcast = dimdst - dimdst_m_cast;
explicit Plan(const Broadcast1DExp<SrcExp, DType, dimdst, dimdst_m_cast> &e)
: src_(MakePlan(e.src_)),
ystride_(e.shape_.ProdShape(dimcast + 1, dimdst - 1)),
length_(e.shape_[dimcast]) {
TypeCheckPass<dimcast != dimdst - 1>
::Error_Expression_Does_Not_Meet_Dimension_Req();
}
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
return src_.Eval(0, (y / ystride_) % length_);
}
private:
expr::Plan<SrcExp, DType> src_;
const index_t ystride_, length_;
};
/*! \brief execution plan of Broadcast1DExp */
template<typename SrcExp, typename DType, int dimdst>
struct Plan<Broadcast1DExp<SrcExp, DType, dimdst, 1>, DType>{
public:
explicit Plan(const Broadcast1DExp<SrcExp, DType, dimdst, 1> &e)
: src_(MakePlan(e.src_)) {}
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
return src_.Eval(0, x);
}
private:
expr::Plan<SrcExp, DType> src_;
};
/*! \brief execution plan of Broadcast1DExp */
template<typename SrcExp, typename DType, int dimdst>
struct Plan<BroadcastScalarExp<SrcExp, DType, dimdst>, DType>{
public:
explicit Plan(const BroadcastScalarExp<SrcExp, DType, dimdst> &e)
: src_(MakePlan(e.src_)) {}
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
return src_.Eval(0, 0);
}
private:
expr::Plan<SrcExp, DType> src_;
};
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXTENSION_BROADCAST_H_
//===== EXPANDED: ../mshadow/mshadow/extension/broadcast.h =====
//===== EXPANDING: ../mshadow/mshadow/extension/unpack_patch2col.h =====
/*!
* Copyright (c) 2014 by Contributors
* \file unpack_patch2col.h
* \brief support for unpack
* \author Tianqi Chen
*/
#ifndef MSHADOW_EXTENSION_UNPACK_PATCH2COL_H_
#define MSHADOW_EXTENSION_UNPACK_PATCH2COL_H_
namespace mshadow {
namespace expr {
/*!
* \brief unpack local (overlap) patches of image to column of mat,
 * can be used to implement convolution; this expression allows unpacking a batch
 * this is a version that supports unpacking multiple images
 * after getting the unpacked mat, we can use: output = dot(weight, mat) to get the convolved results
* \tparam SrcExp source expression
* \tparam dstdim destination dimension
*/
template<typename SrcExp, typename DType, int srcdim>
struct UnpackPatchToColXExp:
public MakeTensorExp<UnpackPatchToColXExp<SrcExp, DType, srcdim>,
SrcExp, 2, DType>{
/*! \brief source operand */
const SrcExp &img_;
/*! \brief patch height */
index_t psize_y_;
/*! \brief patch width */
index_t psize_x_;
/*! \brief patch stride */
index_t pstride_y_;
index_t pstride_x_;
/*! \brief patch dilate */
index_t pdilate_y_;
index_t pdilate_x_;
/*! \brief number of input channel */
index_t i_channel_;
/*! \brief height of img */
index_t i_height_;
/*! \brief width of img */
index_t i_width_;
/*! \brief constructor */
UnpackPatchToColXExp(const SrcExp &img,
index_t psize_y,
index_t psize_x,
index_t pstride_y,
index_t pstride_x,
index_t pdilate_y,
index_t pdilate_x)
: img_(img), psize_y_(psize_y), psize_x_(psize_x),
pstride_y_(pstride_y), pstride_x_(pstride_x),
pdilate_y_(pdilate_y), pdilate_x_(pdilate_x){
Shape<srcdim> imshape = ShapeCheck<srcdim, SrcExp>::Check(img_);
CHECK(imshape[srcdim - 1] >= psize_x && imshape[srcdim - 2] >= psize_y)
<< "UnpackPatchToCol:image shape smaller than patch size";
this->i_channel_ = imshape[srcdim - 3];
this->i_height_ = imshape[srcdim - 2];
this->i_width_ = imshape[srcdim - 1];
// calculate number of batches
const index_t num = imshape.ProdShape(0, srcdim - 3);
const index_t o_height = (i_height_ -
(pdilate_y * (psize_y - 1) + 1)) / pstride_y + 1;
const index_t o_width = (i_width_ -
(pdilate_x * (psize_x - 1) + 1)) / pstride_x + 1;
this->shape_[1] = o_height * o_width * num;
this->shape_[0] = psize_y * psize_x * i_channel_;
}
};
/*!
* \brief unpack local (overlap) patches of image to column of mat, can be used to implement convolution
 * after getting the unpacked mat, we can use: output = dot(weight, mat) to get the convolved results; the relations are:
 *
 * weight; shape[0]: out_channel, shape[1]: ichannel * psize_y * psize_x
 * output; shape[0]: out_channel, shape[1]: out_height * out_width * num_of_images
 * out_height = (in_height - psize_y) / pstride + 1, this means imperfect patches are padded with 0
* out_width = (in_width - psize_x) / pstride + 1
*
* \return mat target matrix; shape[0]: in_channel*psize_y*psize_x shape[1]: out_height*out_width * num_of_images
* \param img source image; shape[-3]: in_channels, shape[-2]: in_height, shape[-1]: in_width, can be 3D or 4D tensor(multiple images)
* \param psize_y height of each patch
* \param psize_x width of each patch
* \param pstride stride of each patch
* \param pdilate dilate of each patch
* \tparam SrcExp source expression
* \tparam DType the type of elements
* \tparam etype type of expression
*/
template<typename SrcExp, typename DType, int etype>
inline UnpackPatchToColXExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>
unpack_patch2col(const Exp<SrcExp, DType, etype> &img,
index_t psize_y, index_t psize_x, index_t pstride, index_t pdilate) {
TypeCheckPass<ExpInfo<SrcExp>::kDim >= 3>
::Error_Expression_Does_Not_Meet_Dimension_Req();
return UnpackPatchToColXExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>
(img.self(), psize_y, psize_x, pstride, pstride, pdilate, pdilate);
}
/*!
 * Overload for specifying stride_y/stride_x and dilate_y/dilate_x separately.
*/
template<typename SrcExp, typename DType, int etype>
inline UnpackPatchToColXExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>
unpack_patch2col(const Exp<SrcExp, DType, etype> &img,
index_t psize_y, index_t psize_x, index_t pstride_y_, index_t pstride_x_,
index_t pdilate_y_, index_t pdilate_x_) {
TypeCheckPass<ExpInfo<SrcExp>::kDim >= 3>
::Error_Expression_Does_Not_Meet_Dimension_Req();
return UnpackPatchToColXExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>
(img.self(), psize_y, psize_x, pstride_y_, pstride_x_, pdilate_y_, pdilate_x_);
}
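// Usage sketch (illustrative, not part of the library; data/weight/col/out2d
// and the size names below are assumptions): the im2col-style convolution
// forward pass this header is written for.
//
//   // data:   (batch, in_channels, height, width)
//   // weight: (out_channels, in_channels * ksize_y * ksize_x)
//   // col:    (in_channels * ksize_y * ksize_x, out_height * out_width * batch)
//   col   = unpack_patch2col(data, ksize_y, ksize_x, kstride, 1 /* no dilation */);
//   out2d = dot(weight, col);   // (out_channels, out_height * out_width * batch)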
//----------------------
// Execution plan
//----------------------
template<typename SrcExp, typename DType, int srcdim>
struct Plan<UnpackPatchToColXExp<SrcExp, DType, srcdim>, DType> {
public:
explicit Plan(const UnpackPatchToColXExp<SrcExp, DType, srcdim> &e)
:src_(MakePlan(e.img_)),
psize_y_(e.psize_y_), psize_x_(e.psize_x_),
pstride_y_(e.pstride_y_), pstride_x_(e.pstride_x_),
i_channel_(e.i_channel_), pdilate_y_(e.pdilate_y_), pdilate_x_(e.pdilate_x_),
i_height_(e.i_height_), i_width_(e.i_width_),
o_height_((i_height_ - (pdilate_y_ * (psize_y_ - 1) + 1)) / pstride_y_ + 1),
o_width_((i_width_ - (pdilate_x_ * (psize_x_ - 1) + 1)) / pstride_x_ + 1) {}
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
const index_t x_offset = i % psize_x_ * pdilate_x_;
const index_t idivp = i / psize_x_;
const index_t y_offset = idivp % psize_y_ * pdilate_y_;
const index_t c = idivp / psize_y_;
const index_t x = (j % o_width_) * pstride_x_ + x_offset;
const index_t jdivw = j / o_width_;
const index_t y = (jdivw % o_height_) * pstride_y_ + y_offset;
const index_t n = jdivw / o_height_;
if (x < i_width_ && y < i_height_) {
return src_.Eval((n * i_channel_ + c) * i_height_ + y, x);
} else {
return DType(0.0f);
}
}
private:
Plan<SrcExp, DType> src_;
const index_t psize_y_, psize_x_, pstride_y_, pstride_x_, i_channel_;
const index_t pdilate_y_, pdilate_x_;
const index_t i_height_, i_width_, o_height_, o_width_;
};
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXTENSION_UNPACK_PATCH2COL_H_
//===== EXPANDED: ../mshadow/mshadow/extension/unpack_patch2col.h =====
//===== EXPANDING: ../mshadow/mshadow/extension/pack_col2patch.h =====
/*!
* Copyright (c) 2014 by Contributors
* \file pack_col2patch.h
* \brief support for pack
* \author Tianqi Chen
*/
#ifndef MSHADOW_EXTENSION_PACK_COL2PATCH_H_
#define MSHADOW_EXTENSION_PACK_COL2PATCH_H_
namespace mshadow {
namespace expr {
/*!
* \brief reverse operation of UnpackPatchToCol,
 * used to backpropagate the gradient
* this is a version supporting multiple images
* \tparam SrcExp source expression
* \tparam DType the type of elements
* \tparam dstdim destination dimension
*/
template<typename SrcExp, typename DType, int dstdim>
struct PackColToPatchXExp:
public MakeTensorExp<PackColToPatchXExp<SrcExp, DType, dstdim>,
SrcExp, dstdim, DType> {
/*! \brief source operand */
const SrcExp &src_;
/*! \brief patch height */
index_t psize_y_;
  /*! \brief patch width */
index_t psize_x_;
/*! \brief patch stride */
index_t pstride_y_;
index_t pstride_x_;
/*! \brief patch dilate */
index_t pdilate_y_;
index_t pdilate_x_;
/*! \brief constructor */
PackColToPatchXExp(const SrcExp &src, Shape<dstdim> imshape,
index_t psize_y, index_t psize_x,
index_t pstride_y, index_t pstride_x,
index_t pdilate_y, index_t pdilate_x)
:src_(src), psize_y_(psize_y), psize_x_(psize_x),
pstride_y_(pstride_y), pstride_x_(pstride_x),
pdilate_y_(pdilate_y), pdilate_x_(pdilate_x){
this->shape_ = imshape;
const index_t o_height = (imshape[dstdim - 2] -
(pdilate_y * (psize_y - 1)+ 1))/pstride_y + 1;
const index_t o_width = (imshape[dstdim - 1] -
(pdilate_x * (psize_x - 1) + 1)) / pstride_x + 1;
Shape<2> sshape = ShapeCheck<2, SrcExp>::Check(src_);
CHECK_EQ(sshape[1], o_height * o_width * imshape.ProdShape(0, dstdim - 3))
<< "PackColToPatchExp: src.size(1) mismatch";
CHECK_EQ(sshape[0], psize_y * psize_x * imshape[dstdim - 3])
<< "PackColToPatchExp: src.size(0) mismatch";
}
};
/*!
 * \brief reverse operation of unpack_patch2col, can be used to implement deconvolution
* \return packed img expression
* \param mat source matrix
* \param imshape shape of target img
* \param psize_y height of each patch
 * \param psize_x width of each patch
* \param pstride stride of each patch
* \tparam SrcExp source expression
* \tparam DType the type of elements
* \tparam dstdim destination dimension
* \tparam etype type of expression
*/
template<typename SrcExp, typename DType, int dstdim, int etype>
inline PackColToPatchXExp<SrcExp, DType, dstdim>
pack_col2patch(const expr::Exp<SrcExp, DType, etype> &src,
Shape<dstdim> imshape, index_t psize_y,
index_t psize_x, index_t pstride, index_t pdilate) {
TypeCheckPass<ExpInfo<SrcExp>::kDim == 2>
::Error_Expression_Does_Not_Meet_Dimension_Req();
CHECK(imshape[dstdim - 1] >= psize_x && imshape[dstdim - 2] >= psize_y)
<< "PackColToPatch:image shape smaller than patch size";
return PackColToPatchXExp<SrcExp, DType, dstdim>(src.self(), imshape,
psize_y, psize_x, pstride, pstride,
pdilate, pdilate);
}
/*!
 * Overload for specifying pstride_y/pstride_x and pdilate_y/pdilate_x separately.
*/
template<typename SrcExp, typename DType, int dstdim, int etype>
inline PackColToPatchXExp<SrcExp, DType, dstdim>
pack_col2patch(const expr::Exp<SrcExp, DType, etype> &src,
Shape<dstdim> imshape, index_t psize_y,
index_t psize_x, index_t pstride_y, index_t pstride_x,
index_t pdilate_y, index_t pdilate_x) {
TypeCheckPass<ExpInfo<SrcExp>::kDim == 2>
::Error_Expression_Does_Not_Meet_Dimension_Req();
CHECK(imshape[dstdim - 1] >= psize_x && imshape[dstdim - 2] >= psize_y)
<< "PackColToPatch:image shape smaller than patch size";
return PackColToPatchXExp<SrcExp, DType, dstdim>(src.self(), imshape,
psize_y, psize_x, pstride_y, pstride_x,
pdilate_y, pdilate_x);
}
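// Usage sketch (illustrative; grad_col/grad_data and the size names are
// assumptions): the backward counterpart of the unpack_patch2col example above,
// scattering (with summation over overlaps) the column gradient back into an
// image-shaped gradient.
//
//   // grad_col:  (in_channels * ksize_y * ksize_x, out_height * out_width * batch)
//   // grad_data: (batch, in_channels, height, width)
//   grad_data = pack_col2patch(grad_col, grad_data.shape_,
//                              ksize_y, ksize_x, kstride, 1 /* no dilation */);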
//----------------------
// Execution plan
//----------------------
template<typename SrcExp, typename DType, int dstdim>
struct Plan<PackColToPatchXExp<SrcExp, DType, dstdim>, DType> {
public:
explicit Plan(const PackColToPatchXExp<SrcExp, DType, dstdim> &e)
:src_(MakePlan(e.src_)), psize_y_(e.psize_y_),
psize_x_(e.psize_x_), pstride_y_(e.pstride_y_), pstride_x_(e.pstride_x_),
i_channel_(e.shape_[dstdim - 3]), pdilate_y_(e.pdilate_y_), pdilate_x_(e.pdilate_x_),
i_height_(e.shape_[dstdim - 2]),
o_height_((e.shape_[dstdim - 2] - (pdilate_y_ * (psize_y_ - 1) + 1)) /
pstride_y_ + 1),
o_width_((e.shape_[dstdim - 1] - (pdilate_x_ * (psize_x_ - 1) + 1)) /
pstride_x_ + 1) {
    // note: i/o conventions are the same as in unpack
}
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
using namespace std;
const index_t y = i % i_height_;
const index_t idivh = i / i_height_;
const index_t c = idivh % i_channel_;
const index_t n = idivh / i_channel_;
const index_t x = j;
const index_t psize_y_dilate = (pdilate_y_ * (psize_y_ - 1) + 1);
const index_t psize_x_dilate = (pdilate_x_ * (psize_x_ - 1) + 1);
const index_t py_min =
y < psize_y_dilate ? y % pdilate_y_ : (y-psize_y_dilate + pstride_y_) / pstride_y_;
const index_t px_min =
x < psize_x_dilate ? x % pdilate_x_ : (x-psize_x_dilate + pstride_x_) / pstride_x_;
const index_t py_max = min((y + pstride_y_) / pstride_y_, o_height_);
const index_t px_max = min((x + pstride_x_) / pstride_x_, o_width_);
DType res = static_cast<DType>(0);
for (index_t py = py_min; py < py_max; py += pdilate_y_) {
for (index_t px = px_min; px < px_max; px += pdilate_x_) {
res += src_.Eval(((c * psize_y_ + (y - py*pstride_y_) / pdilate_y_) * psize_x_ +
(x - px * pstride_x_) / pdilate_x_),
(n * o_height_ + py) * o_width_ + px);
}
}
return res;
}
private:
Plan<SrcExp, DType> src_;
const index_t psize_y_, psize_x_, pstride_y_, pstride_x_, i_channel_;
const index_t pdilate_y_, pdilate_x_;
const index_t i_height_, o_height_, o_width_;
};
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXTENSION_PACK_COL2PATCH_H_
//===== EXPANDED: ../mshadow/mshadow/extension/pack_col2patch.h =====
//===== EXPANDING: ../mshadow/mshadow/extension/reshape.h =====
/*!
* Copyright (c) 2014 by Contributors
* \file reshape.h
* \brief support for reshape
* \author Tianqi Chen
*/
#ifndef MSHADOW_EXTENSION_RESHAPE_H_
#define MSHADOW_EXTENSION_RESHAPE_H_
namespace mshadow {
namespace expr {
/*!
* \brief reshape the content to another shape
* input: Tensor<Device,dimsrc>: ishape
* output: Tensor<Device,dimdst> ishape.Size() == oshape.Size()
* \tparam SrcExp source expression
* \tparam dimdst target dimension
* \tparam dimsrc source dimension
*/
template<typename SrcExp, typename DType, int dimdst, int dimsrc>
struct ReshapeExp:
public MakeTensorExp<ReshapeExp<SrcExp, DType, dimdst, dimsrc>,
SrcExp, dimdst, DType> {
/*! \brief source expression */
const SrcExp &src_;
/*! \brief smallest dimension of input */
index_t ishapex_;
/*! \brief constructor */
ReshapeExp(const SrcExp &src, Shape<dimdst> shape)
: src_(src) {
Shape<dimsrc> ishape = ShapeCheck<dimsrc, SrcExp>::Check(src_);
CHECK_EQ(ishape.Size(), shape.Size()) << "reshape size must match";
ishapex_ = ishape[dimsrc - 1];
this->shape_ = shape;
}
};
/*!
 * \brief an expression that reshapes a tensor to another shape
 * \param src Tensor<Device,dimsrc>:
 * \param oshape target shape
 * \return an expression with type Tensor<Device,dimdst>
* \tparam SrcExp source expression
* \tparam etype source expression type
* \tparam dimdst target dimension
*/
template<typename SrcExp, typename DType, int etype, int dimdst>
inline ReshapeExp<SrcExp, DType, dimdst, ExpInfo<SrcExp>::kDim>
reshape(const Exp<SrcExp, DType, etype> &src, Shape<dimdst> oshape) {
return ReshapeExp<SrcExp, DType, dimdst, ExpInfo<SrcExp>::kDim>
(src.self(), oshape);
}
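// Usage sketch (illustrative, assuming mshadow's NewTensor/FreeSpace helpers):
// reshape only reinterprets the flat index space, so the element counts must
// match.
//
//   Tensor<cpu, 2, float> mat = NewTensor<cpu>(Shape2(6, 4), 0.0f);
//   Tensor<cpu, 3, float> cub = NewTensor<cpu>(Shape3(2, 3, 4), 0.0f);
//   cub = reshape(mat, cub.shape_);   // 6 * 4 == 2 * 3 * 4 elements
//   FreeSpace(&mat); FreeSpace(&cub);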
//----------------------
// Execution plan
//----------------------
template<typename SrcExp, typename DType, int dimdst, int dimsrc>
struct Plan<ReshapeExp<SrcExp, DType, dimdst, dimsrc>, DType> {
public:
explicit Plan(const ReshapeExp<SrcExp, DType, dimdst, dimsrc> &e)
: src_(MakePlan(e.src_)),
oshapex_(e.shape_[dimdst - 1]), ishapex_(e.ishapex_) {}
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
const index_t idx = y * oshapex_ + x;
return src_.Eval(idx / ishapex_, idx % ishapex_);
}
private:
Plan<SrcExp, DType> src_;
const index_t oshapex_, ishapex_;
};
// special work plan for 1 dimensional data
template<typename SrcExp, typename DType, int dimdst>
struct Plan<ReshapeExp<SrcExp, DType, dimdst, 1>, DType> {
public:
explicit Plan(const ReshapeExp<SrcExp, DType, dimdst, 1> &e)
: src_(MakePlan(e.src_)), oshapex_(e.shape_[dimdst - 1]) {
}
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
return src_.Eval(0, y * oshapex_ + x);
}
private:
Plan<SrcExp, DType> src_;
const index_t oshapex_;
};
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXTENSION_RESHAPE_H_
//===== EXPANDED: ../mshadow/mshadow/extension/reshape.h =====
//===== EXPANDING: ../mshadow/mshadow/extension/swapaxis.h =====
/*!
* Copyright (c) 2014 by Contributors
* \file swapaxis.h
* \brief support for swapaxis
* \author Tianqi Chen
*/
#ifndef MSHADOW_EXTENSION_SWAPAXIS_H_
#define MSHADOW_EXTENSION_SWAPAXIS_H_
namespace mshadow {
namespace expr {
/*!
* \brief swap two axis of a tensor
* input: Tensor<Device,dim>: ishape
 * output: Tensor<Device,dimdst> oshape[a1], oshape[a2] = ishape[a2], ishape[a1]
*
* \tparam SrcExp type of source expression
* \tparam DType the type of elements
* \tparam dimsrc source dimension, assert a1 > a2
* \tparam m_a1 one dimension to be swapped, encoded by dimsrc - a1
* \tparam a2 second dimension to be swapped, encoded by a2
*/
template<typename SrcExp, typename DType, int dimsrc, int m_a1, int a2>
struct SwapAxisExp:
public MakeTensorExp<SwapAxisExp<SrcExp, DType, dimsrc, m_a1, a2>,
SrcExp, dimsrc, DType> {
// decode the a1, a2
static const int a1 = dimsrc - m_a1;
/*! \brief source expression */
const SrcExp &src_;
/*! \brief constructor */
explicit SwapAxisExp(const SrcExp &src) : src_(src) {
this->shape_ = ShapeCheck<dimsrc, SrcExp>::Check(src);
std::swap(this->shape_[a1], this->shape_[a2]);
}
};
/*!
 * \brief an expression that swaps two axes of a tensor
 * \param src Tensor<Device,dimsrc>:
 * \return an expression with type Tensor<Device,dimsrc>, with axes a1 and a2 swapped
* \tparam a1 higher dimension to be swapped, assert a1 > a2
* \tparam a2 lower dimension to be swapped
* \tparam SrcExp source expression
* \tparam DType the type of elements
* \tparam etype source expression type
*/
template<int a1, int a2, typename SrcExp, typename DType, int etype>
inline SwapAxisExp<SrcExp, DType, ExpInfo<SrcExp>::kDim,
ExpInfo<SrcExp>::kDim - a1, a2>
swapaxis(const Exp<SrcExp, DType, etype> &src) {
typedef ExpInfo<SrcExp> Info;
TypeCheckPass<Info::kDim >= a1 + 1 && Info::kDim >= a2 + 1 &&
a2 < a1>::Error_Expression_Does_Not_Meet_Dimension_Req();
return SwapAxisExp<SrcExp, DType, ExpInfo<SrcExp>::kDim,
ExpInfo<SrcExp>::kDim - a1, a2>(src.self());
}
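// Usage sketch (illustrative, assuming NewTensor/FreeSpace): swap the two
// highest axes of a 3D tensor; the template arguments are <a1, a2> with a1 > a2.
//
//   Tensor<cpu, 3, float> src = NewTensor<cpu>(Shape3(2, 3, 4), 0.0f);
//   Tensor<cpu, 3, float> dst = NewTensor<cpu>(Shape3(3, 2, 4), 0.0f);
//   dst = swapaxis<1, 0>(src);        // shape (2, 3, 4) -> (3, 2, 4)
//   FreeSpace(&src); FreeSpace(&dst);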
template<typename SrcExp, typename DType, int dimsrc, int m_a1, int a2>
struct Plan<SwapAxisExp<SrcExp, DType, dimsrc, m_a1, a2>, DType> {
public:
// decode the a1
static const int a1 = dimsrc - m_a1;
explicit Plan(const SwapAxisExp<SrcExp, DType, dimsrc, m_a1, a2> &e)
: src_(MakePlan(e.src_)),
shapey_(e.shape_.ProdShape(a1 + 1, dimsrc - 1)),
shapez_(e.shape_[a1]),
shapec_(e.shape_.ProdShape(a2 + 1, a1)),
shapen_(e.shape_[a2]) {}
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
const index_t y = i % shapey_;
i /= shapey_;
const index_t z = i % shapez_;
i /= shapez_;
const index_t c = i % shapec_;
i /= shapec_;
const index_t n = i % shapen_;
// swap z and n
return src_.Eval(((((i / shapen_) * shapez_ + z) * shapec_ +
c) * shapen_ + n) * shapey_ + y, j);
}
private:
Plan<SrcExp, DType> src_;
const index_t shapey_, shapez_, shapec_, shapen_;
};
template<typename SrcExp, typename DType, int dimsrc, int a2>
struct Plan<SwapAxisExp<SrcExp, DType, dimsrc, 1, a2>, DType> {
public:
explicit Plan(const SwapAxisExp<SrcExp, DType, dimsrc, 1, a2> &e)
: src_(MakePlan(e.src_)),
shapex_(e.shape_[dimsrc - 1]),
shapey_(e.shape_.ProdShape(a2 + 1, dimsrc - 1)),
shapez_(e.shape_[a2]) {}
MSHADOW_XINLINE DType Eval(index_t i, index_t x) const {
// swap x and z
const index_t y = i % shapey_;
i /= shapey_;
const index_t z = i % shapez_;
const index_t n = i / shapez_;
return src_.Eval((n * shapex_ + x) * shapey_ + y , z);
}
private:
Plan<SrcExp, DType> src_;
const index_t shapex_, shapey_, shapez_;
};
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXTENSION_SWAPAXIS_H_
//===== EXPANDED: ../mshadow/mshadow/extension/swapaxis.h =====
//===== EXPANDING: ../mshadow/mshadow/extension/reduceto1d.h =====
/*!
* Copyright (c) 2014 by Contributors
* \file reduceto1d.h
* \brief support for sum_rows and sumall_except_dim
* \author Tianqi Chen
*/
#ifndef MSHADOW_EXTENSION_REDUCETO1D_H_
#define MSHADOW_EXTENSION_REDUCETO1D_H_
namespace mshadow {
namespace expr {
/*!
* \brief reduction to 1 dimension tensor
* input: Tensor<Device,k>: ishape
* output: Tensor<Device,1> shape[0] = ishape[dimkeep];
*
* \tparam SrcExp type of expression to be reduced
* \tparam DType the data type of the scalar
* \tparam Reducer which reducer to use
* \tparam m_dimkeep which dimension to be kept, encoded with dimsrc - dimkeep
*/
template<typename SrcExp, typename DType, typename Reducer, int m_dimkeep>
struct ReduceTo1DExp:
public Exp<ReduceTo1DExp<SrcExp, DType, Reducer, m_dimkeep>,
DType, type::kComplex> {
/*! \brief source operand */
const SrcExp &src_;
  /*! \brief scaling factor applied to the source operand */
DType scale_;
/*! \brief construct a repmat expression from src and nrow */
ReduceTo1DExp(const SrcExp& src, DType scale) : src_(src), scale_(scale) {}
};
/*!
 * \brief a sum over all dimensions except dimkeep
 * \param exp input expression that must be a matrix Tensor<?,2>
 * \return an expression with type Tensor<Device,1>
* \tparam dimkeep the dimension that will be kept
* \tparam SrcExp expression
* \tparam etype type of expression
*/
template<int dimkeep, typename SrcExp, typename DType, int etype>
inline ReduceTo1DExp<SrcExp, DType, red::sum,
ExpInfo<SrcExp>::kDim - dimkeep>
sumall_except_dim(const Exp<SrcExp, DType, etype> &exp) {
return ReduceTo1DExp<SrcExp, DType, red::sum,
ExpInfo<SrcExp>::kDim - dimkeep>(exp.self(), DType(1));
}
/*!
* \brief reduce over all dimensions, except dimkeep
* \param exp input expression that must be a matrix Tensor<?,2>
 * \return an expression with type Tensor<Device,1>
* \tparam dimkeep the dimension that will be kept
* \tparam SrcExp expression
* \tparam etype type of expression
*/
template<int dimkeep, typename Reducer, typename SrcExp, typename DType, int etype>
inline ReduceTo1DExp<SrcExp, DType, Reducer,
ExpInfo<SrcExp>::kDim - dimkeep>
reduce_except_dim(const Exp<SrcExp, DType, etype> &exp) {
return ReduceTo1DExp<SrcExp, DType, Reducer,
ExpInfo<SrcExp>::kDim - dimkeep>(exp.self(), DType(1));
}
/*!
 * \brief an expression that sums over the rows of a matrix
 * \param exp input expression that must be a matrix Tensor<?, 2>
 * \return an expression with type Tensor<Device, 1>
* \tparam SrcExp expression
* \tparam etype type of expression
*/
template<typename SrcExp, typename DType, int etype>
inline ReduceTo1DExp<SrcExp, DType, red::sum, 1>
sum_rows(const Exp<SrcExp, DType, etype> &exp) {
TypeCheckPass<ExpInfo<SrcExp>::kDim ==2>
::Error_Expression_Does_Not_Meet_Dimension_Req();
return sumall_except_dim<1>(exp);
}
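// Usage sketch (illustrative; grad, gbias, mat, colsum and channels are assumed
// names): reduce everything except one dimension, e.g. a per-channel sum of a
// (batch, channels, height, width) gradient, or a plain sum over the rows of a
// matrix.
//
//   gbias  = sumall_except_dim<1>(grad);   // grad: 4D, gbias: Tensor<cpu, 1> of length channels
//   colsum = sum_rows(mat);                // mat: (3, 4), colsum: length 4, summed over the 3 rows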
template<typename SV, typename Device, typename DType,
typename SrcExp, typename Reducer, int m_dimkeep>
struct ExpComplexEngine<SV,
Tensor<Device, 1, DType>,
ReduceTo1DExp<SrcExp, DType, Reducer, m_dimkeep>,
DType> {
static const int dimkeep = ExpInfo<SrcExp>::kDim - m_dimkeep;
inline static void Eval(Tensor<Device, 1, DType> *dst,
const ReduceTo1DExp<SrcExp, DType,
Reducer, m_dimkeep> &exp) {
TypeCheckPass<m_dimkeep != 1>
::Error_Expression_Does_Not_Meet_Dimension_Req();
MapReduceKeepHighDim<SV, Reducer, dimkeep>(dst, exp.src_, exp.scale_);
}
};
template<typename SV, typename Device, typename DType,
typename SrcExp, typename Reducer>
struct ExpComplexEngine<SV,
Tensor<Device, 1, DType>,
ReduceTo1DExp<SrcExp, DType, Reducer, 1>, DType> {
inline static void Eval(Tensor<Device, 1, DType> *dst,
const ReduceTo1DExp<SrcExp, DType, Reducer, 1> &exp) {
MapReduceKeepLowest<SV, Reducer>(dst, exp.src_, exp.scale_);
}
};
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXTENSION_REDUCETO1D_H_
//===== EXPANDED: ../mshadow/mshadow/extension/reduceto1d.h =====
//===== EXPANDING: ../mshadow/mshadow/extension/spatial_pool.h =====
/*!
* Copyright (c) 2014 by Contributors
* \file spatial_pool.h
* \brief support for spatial pooling
* \author Tianqi Chen
*/
#ifndef MSHADOW_EXTENSION_SPATIAL_POOL_H_
#define MSHADOW_EXTENSION_SPATIAL_POOL_H_
namespace mshadow {
namespace expr {
/*!
* \brief pooling expression, do reduction over local patches of a image
* \tparam Reducer reduction method during pooling
* \tparam SrcExp source expression to be pooled from
* \tparam DType the content data type
* \tparam srcdim dimension of src
*/
template<typename Reducer, typename SrcExp, typename DType, int srcdim>
struct PoolingExp:
public MakeTensorExp<PoolingExp<Reducer, SrcExp, DType, srcdim>,
SrcExp, srcdim, DType> {
/*! \brief source operand */
const SrcExp &src_;
/*! \brief kernel size in height */
index_t ksize_y_;
/*! \brief kernel size in width */
index_t ksize_x_;
  /*! \brief kernel stride in y direction */
  index_t kstride_y_;
  /*! \brief kernel stride in x direction */
  index_t kstride_x_;
  /*! \brief source height, shape[srcdim - 2] */
  index_t src_height_;
  /*! \brief source width, shape[srcdim - 1] */
index_t src_width_;
/*! \brief constructor */
PoolingExp(const SrcExp &src,
index_t ksize_y, index_t ksize_x, index_t kstride_y, index_t kstride_x)
: src_(src), ksize_y_(ksize_y), ksize_x_(ksize_x),
kstride_y_(kstride_y), kstride_x_(kstride_x) {
Shape<srcdim> sshape = ShapeCheck<srcdim, SrcExp>::Check(src_);
CHECK(sshape[srcdim - 1] >= ksize_x && sshape[srcdim - 2] >= ksize_y)
<< "PoolingExp: kernel must be smaller than image";
this->src_height_ = sshape[srcdim - 2];
this->src_width_ = sshape[srcdim - 1];
this->shape_ = sshape;
this->shape_[srcdim - 2] = (src_height_ - ksize_y) / kstride_y + 1;
this->shape_[srcdim - 1] = (src_width_ - ksize_x) / kstride_x + 1;
}
/*! \brief constructor, specify shape */
PoolingExp(const SrcExp &src, Shape<2> pshape,
index_t ksize_y, index_t ksize_x, index_t kstride_y, index_t kstride_x)
: src_(src), ksize_y_(ksize_y), ksize_x_(ksize_x),
kstride_y_(kstride_y), kstride_x_(kstride_x) {
Shape<srcdim> sshape = ShapeCheck<srcdim, SrcExp>::Check(src_);
CHECK(sshape[srcdim - 1] >= ksize_x && sshape[srcdim - 2] >= ksize_y)
<< "PoolingExp: kernel must be smaller than image";
this->src_height_ = sshape[srcdim - 2];
this->src_width_ = sshape[srcdim - 1];
this->shape_ = sshape;
this->shape_[srcdim - 2] = pshape[0];
this->shape_[srcdim - 1] = pshape[1];
}
};
/*!
* \brief pooling subregion results together
* \param src source image, shape: (batch, channel, height, width)
* \param ksize_y kernel size in height
* \param ksize_x kernel size in width
 * \param kstride_y stride in y direction
 * \param kstride_x stride in x direction
* \return expression of pooled result
* \tparam Reducer reducer type
* \tparam SrcExp source expression
* \tparam DType the content data type
* \tparam etype type of expression
*/
template<typename Reducer, typename SrcExp, typename DType, int etype>
inline PoolingExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim>
pool(const Exp<SrcExp, DType, etype> &src,
index_t ksize_y, index_t ksize_x, index_t kstride_y, index_t kstride_x) {
TypeCheckPass<ExpInfo<SrcExp>::kDim >= 2>
::Error_Expression_Does_Not_Meet_Dimension_Req();
return PoolingExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim>
(src.self(), ksize_y, ksize_x, kstride_y, kstride_x);
}
/*!
* \brief same as pool, except the output shape is specified by pshape
* \param src source image
 * \param pshape output shape
 * \param ksize_y kernel size in y
 * \param ksize_x kernel size in x
 * \param kstride_y stride in y direction
 * \param kstride_x stride in x direction
* \return expression of pooled result
* \tparam Reducer reducer type
* \tparam SrcExp source expression
* \tparam DType the content data type
* \tparam etype type of expression
*/
template<typename Reducer, typename SrcExp,
typename DType, int etype>
inline PoolingExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim>
pool(const Exp<SrcExp, DType, etype> &src, Shape<2> pshape,
index_t ksize_y, index_t ksize_x, index_t kstride_y, index_t kstride_x) {
TypeCheckPass<ExpInfo<SrcExp>::kDim >= 2>
::Error_Expression_Does_Not_Meet_Dimension_Req();
return PoolingExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim>
(src.self(), pshape, ksize_y, ksize_x, kstride_y, kstride_x);
}
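// Usage sketch (illustrative; in/out are assumed, pre-allocated tensors):
// 2x2 max pooling with stride 2 using the red::maximum reducer.
//
//   // in:  (batch, channel, H, W)
//   // out: (batch, channel, (H - 2) / 2 + 1, (W - 2) / 2 + 1)
//   out = pool<red::maximum>(in, 2, 2, 2, 2);
//   // or with an explicitly given output spatial shape:
//   out = pool<red::maximum>(in, Shape2(out.size(2), out.size(3)), 2, 2, 2, 2);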
//----------------------
// Execution plan
//----------------------
template<typename Reducer, typename SrcExp, typename DType, int srcdim>
struct Plan<PoolingExp< Reducer, SrcExp, DType, srcdim>, DType> {
public:
explicit Plan(const PoolingExp<Reducer, SrcExp, DType, srcdim> &e)
: src_(MakePlan(e.src_)),
ksize_y_(e.ksize_y_), ksize_x_(e.ksize_x_),
kstride_y_(e.kstride_y_), kstride_x_(e.kstride_x_),
src_height_(e.src_height_), src_width_(e.src_width_),
new_height_(e.shape_[srcdim - 2]) {}
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
using namespace std;
const index_t py = i % new_height_;
const index_t y_start = py * kstride_y_;
const index_t y_end = min(y_start + ksize_y_, src_height_);
const index_t px = j;
const index_t x_start = px * kstride_x_;
const index_t x_end = min(x_start + ksize_x_, src_width_);
const index_t c = i / new_height_;
DType res; Reducer::SetInitValue(res);
for (index_t y = y_start; y < y_end; ++y) {
for (index_t x = x_start; x < x_end; ++x) {
Reducer::Reduce(res, src_.Eval(c * src_height_ + y, x));
}
}
return res;
}
private:
Plan<SrcExp, DType> src_;
const index_t ksize_y_, ksize_x_, kstride_y_, kstride_x_;
const index_t src_height_, src_width_;
const index_t new_height_;
};
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXTENSION_SPATIAL_POOL_H_
//===== EXPANDED: ../mshadow/mshadow/extension/spatial_pool.h =====
//===== EXPANDING: ../mshadow/mshadow/extension/spatial_unpool.h =====
/*!
* Copyright (c) 2014 by Contributors
* \file spatial_unpool.h
* \brief support for unpool
* \author Tianqi Chen
*/
#ifndef MSHADOW_EXTENSION_SPATIAL_UNPOOL_H_
#define MSHADOW_EXTENSION_SPATIAL_UNPOOL_H_
namespace mshadow {
namespace expr {
/*!
 * \brief unpooling expression, the reverse operation of pooling, used to pass gradient back
* \tparam Reducer reduction method during pooling
* \tparam SrcExp source expression to be pooled from
* \tparam DType the content data type
* \tparam srcdim dimension of src
*/
template<typename Reducer, typename SrcExp, typename DType, int srcdim>
struct UnPoolingExp:
public MakeTensorExp<UnPoolingExp<Reducer, SrcExp, DType, srcdim>,
SrcExp, srcdim, DType> {
/*! \brief source input, corresponds to src in pooling */
const SrcExp &data_src_;
/*! \brief result of pooled data, corresponds to result of pooling */
const SrcExp &data_pooled_;
  /*! \brief gradient data of pooled part, to be propagated down */
const SrcExp &grad_pooled_;
  /*! \brief height of pooled expression */
  index_t pshape_y_;
  /*! \brief width of pooled expression */
  index_t pshape_x_;
/*! \brief kernel size in height */
index_t ksize_y_;
/*! \brief kernel size in width */
index_t ksize_x_;
  /*! \brief kernel stride in y direction */
  index_t kstride_y_;
  /*! \brief kernel stride in x direction */
  index_t kstride_x_;
/*! \brief constructor */
UnPoolingExp(const SrcExp &data_src,
const SrcExp &data_pooled,
const SrcExp &grad_pooled,
index_t ksize_y, index_t ksize_x, index_t kstride_y, index_t kstride_x)
: data_src_(data_src), data_pooled_(data_pooled),
grad_pooled_(grad_pooled),
ksize_y_(ksize_y), ksize_x_(ksize_x),
kstride_y_(kstride_y), kstride_x_(kstride_x) {
Shape<srcdim> pshape = ShapeCheck<srcdim, SrcExp>::Check(grad_pooled);
typedef ShapeCheck<srcdim, SrcExp> ShapeCheckSrcDimSrcExp;
CHECK_EQ(pshape, ShapeCheckSrcDimSrcExp::Check(data_pooled))
<< "UnPoolingExp: pooled shape mismatch";
Shape<srcdim> sshape = ShapeCheck<srcdim, SrcExp>::Check(data_src);
for (int k = 0; k < srcdim - 2; ++k) {
CHECK_EQ(pshape[k], sshape[k]) << "UnPoolingExp: pool and src shape mismatch";
}
pshape_x_ = pshape[srcdim - 1];
pshape_y_ = pshape[srcdim - 2];
this->shape_ = sshape;
}
};
/*!
 * \brief unpooling gradient for 4D, backpropagate gradient values, reverse operation of pooling,
* same as unpooling, but allows unequal size of kernel
* \param data_src source input, corresponds to src in pooling
* \param data_pooled result of pooled data, corresponds to result of pooling
 * \param grad_pooled gradient data of pooled part, to be propagated down
 * \param ksize_y kernel height
 * \param ksize_x kernel width
 * \param kstride_y stride in y direction
 * \param kstride_x stride in x direction
 * \return expression corresponding to unpooled 4D Tensor, storing backpropagated gradient
* \tparam Reducer reducer type
* \tparam SrcExp source expression
* \tparam DType the content data type
* \tparam etype type of expression
*/
template<typename Reducer, typename SrcExp, typename DType, int etype>
inline UnPoolingExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim>
unpool(const Exp<SrcExp, DType, etype> &data_src,
const Exp<SrcExp, DType, etype> &data_pooled,
const Exp<SrcExp, DType, etype> &grad_pooled,
index_t ksize_y, index_t ksize_x, index_t kstride_y, index_t kstride_x) {
return UnPoolingExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim>
(data_src.self(), data_pooled.self(), grad_pooled.self(),
ksize_y, ksize_x, kstride_y, kstride_x);
}
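// Usage sketch (illustrative only; names are assumptions): backpropagate through
// the 2x2/stride-2 max pooling shown earlier. `data` is the forward input,
// `pooled` the forward output and `grad_pooled` the gradient w.r.t. the output.
//   grad_data = unpool<red::maximum>(data, pooled, grad_pooled, 2, 2, 2, 2);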
//----------------------
// Execution plan
//----------------------
template<typename Reducer, typename SrcExp, typename DType, int srcdim>
struct Plan<UnPoolingExp<Reducer, SrcExp, DType, srcdim>, DType> {
public:
explicit Plan(const UnPoolingExp<Reducer, SrcExp, DType, srcdim> &e)
: data_src_(MakePlan(e.data_src_)), data_pooled_(MakePlan(e.data_pooled_)),
grad_pooled_(MakePlan(e.grad_pooled_)), sshape_y_(e.shape_[srcdim - 2]),
pshape_y_(e.pshape_y_), pshape_x_(e.pshape_x_),
ksize_y_(e.ksize_y_), ksize_x_(e.ksize_x_),
kstride_y_(e.kstride_y_), kstride_x_(e.kstride_x_) {}
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
using namespace std;
const index_t x = j;
const index_t y = i % sshape_y_;
const index_t c = i / sshape_y_;
const DType vsrc = data_src_.Eval(i, j);
const index_t py_min =
y < ksize_y_ ? 0 : (y - ksize_y_ + kstride_y_) / kstride_y_;
const index_t px_min =
x < ksize_x_ ? 0 : (x - ksize_x_ + kstride_x_) / kstride_x_;
const index_t py_max = min((y + kstride_y_) / kstride_y_, pshape_y_);
const index_t px_max = min((x + kstride_x_) / kstride_x_, pshape_x_);
DType val = static_cast<DType>(0);
for (index_t py = py_min; py < py_max; ++py) {
for (index_t px = px_min; px < px_max; ++px) {
val += Reducer::PartialGrad(vsrc,
data_pooled_.Eval(c * pshape_y_ + py, px)) *
grad_pooled_.Eval(c * pshape_y_ + py, px);
}
}
return val;
}
private:
Plan<SrcExp, DType> data_src_, data_pooled_, grad_pooled_;
const index_t sshape_y_, pshape_y_, pshape_x_;
const index_t ksize_y_, ksize_x_;
const index_t kstride_y_, kstride_x_;
};
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXTENSION_SPATIAL_UNPOOL_H_
//===== EXPANDED: ../mshadow/mshadow/extension/spatial_unpool.h =====
//===== EXPANDING: ../mshadow/mshadow/extension/channel_pool.h =====
/*!
* Copyright (c) 2014 by Contributors
* \file channel_pool.h
* \brief support for chpool
* \author Tianqi Chen
*/
#ifndef MSHADOW_EXTENSION_CHANNEL_POOL_H_
#define MSHADOW_EXTENSION_CHANNEL_POOL_H_
namespace mshadow {
namespace expr {
/*!
* \brief channel pooling expression, do reduction over (local nearby) channels,
* used to implement local response normalization
* \tparam Reducer reduction method during pooling
* \tparam SrcExp source expression to be pooled from
* \tparam DType the type of elements
* \tparam srcdim dimension of src
*/
template<typename Reducer, typename SrcExp, typename DType, int srcdim>
struct ChannelPoolingExp:
public MakeTensorExp<ChannelPoolingExp<Reducer, SrcExp, DType, srcdim>,
SrcExp, srcdim, DType> {
/*! \brief source operand */
const SrcExp &src_;
/*! \brief neighbor size */
index_t nsize_;
/*! \brief stride of pooling */
index_t stride_;
/*! \brief pad of pooling of each side */
index_t pad_;
index_t src_channel_;
/*! \brief constructor */
ChannelPoolingExp(const SrcExp &src, index_t nsize, index_t stride, index_t pad)
: src_(src), nsize_(nsize), stride_(stride), pad_(pad) {
this->shape_ = ShapeCheck<srcdim, SrcExp>::Check(src_);
this->src_channel_ = this->shape_[srcdim - 3];
CHECK_GE(this->shape_[srcdim - 3], nsize_)
<< "chpool: local size must be smaller than nchannels";
this->shape_[srcdim - 3] = (this->src_channel_ - nsize + pad * 2 + 1) / stride;
}
};
/*!
* \brief channel pooling, do reduction over (local nearby) channels,
* used to implement local response normalization
* \param src source data
* \param nsize neighbor size
* \return expression of pooled result
* \tparam Reducer reducer type
* \tparam SrcExp source expression
* \tparam DType the type of elements
* \tparam etype type of expression
*/
template<typename Reducer, typename SrcExp, typename DType, int etype>
inline ChannelPoolingExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim>
chpool(const Exp<SrcExp, DType, etype> &src, index_t nsize) {
TypeCheckPass<ExpInfo<SrcExp>::kDim >= 3>
::Error_Expression_Does_Not_Meet_Dimension_Req();
CHECK_EQ(nsize % 2, 1) << "chpool: if no pad is specified, local size must be odd";
return ChannelPoolingExp<Reducer, SrcExp,
DType, ExpInfo<SrcExp>::kDim>(src.self(), nsize, 1, nsize / 2);
}
template<typename Reducer, typename SrcExp, typename DType, int etype>
inline ChannelPoolingExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim>
chpool(const Exp<SrcExp, DType, etype> &src, index_t nsize, index_t stride, index_t pad) {
TypeCheckPass<ExpInfo<SrcExp>::kDim >= 3>
::Error_Expression_Does_Not_Meet_Dimension_Req();
return ChannelPoolingExp<Reducer, SrcExp,
DType, ExpInfo<SrcExp>::kDim>(src.self(), nsize, stride, pad);
}
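// Usage sketch (illustrative only; `fmap` is an assumed 4D feature map): sum over
// a local window of 5 neighbouring channels, as in local response normalization.
// The single-argument form pads by nsize / 2 so the channel count is preserved.
//   tmp = chpool<red::sum>(fmap * fmap, 5);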
//----------------------
// Execution plan
//----------------------
template<typename Reducer, typename SrcExp, typename DType, int srcdim>
struct Plan<ChannelPoolingExp<Reducer, SrcExp, DType, srcdim>, DType> {
public:
explicit Plan(const ChannelPoolingExp<Reducer, SrcExp, DType, srcdim> &e)
: src_(MakePlan(e.src_)), channel_(e.shape_[srcdim - 3]),
height_(e.shape_[srcdim - 2]), width_(e.shape_[srcdim - 1]),
hnsize_(e.nsize_), stride_(e.stride_), pad_(e.pad_),
src_channel_(e.src_channel_) {}
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
using namespace std;
const index_t y = i % height_;
i /= height_;
const index_t c = i % channel_;
const index_t n = i / channel_;
const index_t x = j;
const index_t cstart = c * stride_ < pad_ ? 0 : c * stride_ - pad_;
const index_t cend = min(cstart + hnsize_, channel_);
DType res; Reducer::SetInitValue(res);
for (index_t cc = cstart; cc < cend; ++cc) {
Reducer::Reduce(res, src_.Eval((n * src_channel_ + cc) * height_ + y, x));
}
return res;
}
private:
Plan<SrcExp, DType> src_;
const index_t channel_, height_, width_, hnsize_, stride_, pad_, src_channel_;
};
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXTENSION_CHANNEL_POOL_H_
//===== EXPANDED: ../mshadow/mshadow/extension/channel_pool.h =====
//===== EXPANDING: ../mshadow/mshadow/extension/channel_unpool.h =====
/*!
* Copyright (c) 2014 by Contributors
 * \file channel_unpool.h
 * \brief support for channel unpooling (ch_unpool)
* \author Tianqi Chen
*/
#ifndef MSHADOW_EXTENSION_CHANNEL_UNPOOL_H_
#define MSHADOW_EXTENSION_CHANNEL_UNPOOL_H_
namespace mshadow {
namespace expr {
/*!
 * \brief channel unpooling expression, gradient of channel pooling,
 *  used to pass gradient back for local response normalization
* \tparam Reducer reduction method during pooling
* \tparam SrcExp source expression to be pooled from
* \tparam DType the type of elements
* \tparam srcdim dimension of src
*/
template<typename Reducer, typename SrcExp, typename DType, int srcdim>
struct ChannelUnpoolingExp:
public MakeTensorExp<ChannelUnpoolingExp<Reducer, SrcExp, DType, srcdim>,
SrcExp, srcdim, DType> {
/*! \brief source input, corresponds to src in pooling */
const SrcExp &data_src_;
/*! \brief result of pooled data, corresponds to result of pooling */
const SrcExp &data_pooled_;
  /*! \brief gradient data of pooled part, to be propagated down */
const SrcExp &grad_pooled_;
/*! \brief channel of pooled expression */
index_t pchannel_;
  /*! \brief neighbor size */
  index_t nsize_;
  /*! \brief stride of pooling */
  index_t kstride_;
/*! \brief pad */
index_t pad_;
/*! \brief constructor */
ChannelUnpoolingExp(const SrcExp &data_src,
const SrcExp &data_pooled,
const SrcExp &grad_pooled,
index_t nsize, index_t kstride, index_t pad)
: data_src_(data_src), data_pooled_(data_pooled),
grad_pooled_(grad_pooled),
nsize_(nsize), kstride_(kstride), pad_(pad) {
Shape<srcdim> pshape = ShapeCheck<srcdim, SrcExp>::Check(grad_pooled);
typedef ShapeCheck<srcdim, SrcExp> ShapeCheckSrcDimSrcExp;
CHECK_EQ(pshape, ShapeCheckSrcDimSrcExp::Check(data_pooled))
<< "ChannelUnPoolingExp: data and grad shape mismatch";
Shape<srcdim> sshape = ShapeCheck<srcdim, SrcExp>::Check(data_src);
for (int k = 0; k < srcdim; ++k) {
if (k == 1) {
continue;
}
CHECK_EQ(pshape[k], sshape[k])
<< "ChannelUnPoolingExp: pooled tensor and src tensor shape mismatch"
<< pshape[k]
<< " vs "
<< sshape[k];
}
pchannel_ = pshape[1];
this->shape_ = sshape;
}
};
/*!
 * \brief channel unpooling, backpropagate the gradient of channel pooling over (local nearby) channels
 * \param data_src source input, corresponds to src in pooling
 * \param data_pooled result of pooled data, corresponds to result of pooling
 * \param grad_pooled gradient data of pooled part, to be propagated down
 * \param nsize neighbor size
 * \param stride stride of the pooling
 * \param pad number of padding at each side
 * \return expression of the unpooled gradient
* \tparam Reducer reducer type
* \tparam SrcExp source expression
* \tparam DType the type of elements
* \tparam etype type of expression
*/
template<typename Reducer, typename SrcExp, typename DType, int etype>
inline ChannelUnpoolingExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim>
ch_unpool(const Exp<SrcExp, DType, etype> &data_src,
const Exp<SrcExp, DType, etype> &data_pooled,
const Exp<SrcExp, DType, etype> &grad_pooled,
index_t nsize, index_t stride, index_t pad) {
TypeCheckPass<ExpInfo<SrcExp>::kDim >= 3>
::Error_Expression_Does_Not_Meet_Dimension_Req();
return ChannelUnpoolingExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim>
(data_src.self(), data_pooled.self(), grad_pooled.self(), nsize, stride, pad);
}
template<typename Reducer, typename SrcExp, typename DType, int etype>
inline ChannelUnpoolingExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim>
ch_unpool(const Exp<SrcExp, DType, etype> &data_src,
const Exp<SrcExp, DType, etype> &data_pooled,
const Exp<SrcExp, DType, etype> &grad_pooled, index_t nsize) {
return ch_unpool(data_src, data_pooled, grad_pooled, nsize, 1, nsize / 2);
}
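// Usage sketch (illustrative only; names are assumptions): backpropagate through
// the 5-channel sum pooling sketched above; the arguments mirror unpool
// (forward input, forward output, output gradient, neighbour size).
//   grad_fmap = ch_unpool<red::sum>(fmap, pooled_fmap, grad_pooled_fmap, 5);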
//----------------------
// Execution plan
//----------------------
template<typename Reducer, typename SrcExp, typename DType, int srcdim>
struct Plan<ChannelUnpoolingExp<Reducer, SrcExp, DType, srcdim>, DType> {
public:
explicit Plan(const ChannelUnpoolingExp<Reducer, SrcExp, DType, srcdim> &e)
: data_src_(e.data_src_), data_pooled_(e.data_pooled_),
grad_pooled_(e.grad_pooled_), channel_(e.shape_[srcdim - 3]),
height_(e.shape_[srcdim - 2]), pchannel_(e.pchannel_),
hnsize_(e.nsize_), stride_(e.kstride_), pad_(e.pad_) {}
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
using namespace std;
const DType vsrc = data_src_.Eval(i, j);
const index_t y = i % height_;
i /= height_;
const index_t c = i % channel_;
const index_t n = i / channel_;
const index_t x = j;
const index_t cstart = c < hnsize_ - pad_ ? 0
: (c - (hnsize_ - pad_) + stride_) / stride_;
const index_t cend = min((c + pad_ + stride_) / stride_, channel_);
DType val = static_cast<DType>(0);
for (index_t cc = cstart; cc < cend; ++cc) {
val += Reducer::PartialGrad(vsrc,
data_pooled_.Eval((n * pchannel_ + cc) * height_ + y, x)) *
grad_pooled_.Eval((n * pchannel_ + cc) * height_ + y, x);
}
return val;
}
private:
Plan<SrcExp, DType> data_src_, data_pooled_, grad_pooled_;
const index_t channel_, height_, pchannel_, hnsize_, stride_, pad_;
};
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXTENSION_CHANNEL_UNPOOL_H_
//===== EXPANDED: ../mshadow/mshadow/extension/channel_unpool.h =====
//===== EXPANDING: ../mshadow/mshadow/extension/pad.h =====
/*!
* Copyright (c) 2014 by Contributors
* \file pad.h
* \brief support for pad
* \author Tianqi Chen
*/
#ifndef MSHADOW_EXTENSION_PAD_H_
#define MSHADOW_EXTENSION_PAD_H_
namespace mshadow {
namespace expr {
/*!
 * \brief padding expression, pad an image with zeros
* \tparam SrcExp source expression
* \tparam DType the type of elements
* \tparam srcdim dimension of src
*/
template<typename SrcExp, typename DType, int srcdim>
struct PaddingExp:
public MakeTensorExp<PaddingExp<SrcExp, DType, srcdim>,
SrcExp, srcdim, DType> {
/*! \brief source operand */
const SrcExp &src_;
/*! \brief pad size in y */
index_t pad_y_;
/*! \brief pad size in x */
index_t pad_x_;
/*! \brief source tensor height */
index_t src_height_;
/*! \brief source tensor width */
index_t src_width_;
/*! \brief constructor */
PaddingExp(const SrcExp &src, index_t pad_y, index_t pad_x)
: src_(src), pad_y_(pad_y), pad_x_(pad_x) {
this->shape_ = ShapeCheck<srcdim, SrcExp>::Check(src_);
src_height_ = this->shape_[srcdim - 2];
src_width_ = this->shape_[srcdim - 1];
this->shape_[srcdim - 2] += pad_y * 2; // height
this->shape_[srcdim - 1] += pad_x * 2; // width
}
};
/*!
 * \brief padding expression, pad an image with zeros on its boundaries; padding affects the last two dimensions (height and width)
* \param src original image batches
* \param pad padding size
* \return expression corresponding to padded result
* \tparam SrcExp source expression
* \tparam DType the content data type
* \tparam etype type of expression
*/
template<typename SrcExp, typename DType, int etype>
inline PaddingExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>
pad(const Exp<SrcExp, DType, etype> &src, index_t pad) {
TypeCheckPass<ExpInfo<SrcExp>::kDim >= 2>
::Error_Expression_Does_Not_Meet_Dimension_Req();
return PaddingExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>(src.self(), pad, pad);
}
/*!
 * \brief padding expression, pad an image with zeros on its boundaries; padding affects the last two dimensions (height and width)
* \param src original image batches
* \param pad_y padding size in y
* \param pad_x padding size in x
* \return expression corresponding to padded result
* \tparam SrcExp source expression
* \tparam DType the content data type
* \tparam etype type of expression
*/
template<typename SrcExp, typename DType, int etype>
inline PaddingExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>
pad(const Exp<SrcExp, DType, etype> &src, index_t pad_y, index_t pad_x) {
TypeCheckPass<ExpInfo<SrcExp>::kDim >= 2>
::Error_Expression_Does_Not_Meet_Dimension_Req();
return PaddingExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>
(src.self(), pad_y, pad_x);
}
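// Usage sketch (illustrative only; `img` is an assumed image batch): zero-pad
// every image by one pixel on each side, growing (height, width) to
// (height + 2, width + 2); pad(img, 1, 2) would pad rows and columns differently.
//   padded = pad(img, 1);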
//----------------------
// Execution plan
//----------------------
template<typename SrcExp, typename DType, int srcdim>
struct Plan<PaddingExp<SrcExp, DType, srcdim>, DType> {
public:
explicit Plan(const PaddingExp<SrcExp, DType, srcdim> &e)
: src_(MakePlan(e.src_)),
pad_y_(e.pad_y_), pad_x_(e.pad_x_),
new_height_(e.shape_[srcdim - 2]),
src_height_(e.src_height_), src_width_(e.src_width_) {}
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
const index_t x = j;
const index_t y = i % new_height_;
const index_t c = i / new_height_;
if (y < pad_y_ || x < pad_x_) return static_cast<DType>(0);
const index_t h = y - pad_y_;
const index_t w = x - pad_x_;
if (h < src_height_ && w < src_width_) {
return src_.Eval(c * src_height_ + h, w);
} else {
return static_cast<DType>(0);
}
}
private:
Plan<SrcExp, DType> src_;
const index_t pad_y_;
const index_t pad_x_;
const index_t new_height_;
const index_t src_height_;
const index_t src_width_;
};
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXTENSION_PAD_H_
//===== EXPANDED: ../mshadow/mshadow/extension/pad.h =====
//===== EXPANDING: ../mshadow/mshadow/extension/crop.h =====
/*!
* Copyright (c) 2014 by Contributors
* \file crop.h
* \brief support for crop
* \author Tianqi Chen
*/
#ifndef MSHADOW_EXTENSION_CROP_H_
#define MSHADOW_EXTENSION_CROP_H_
namespace mshadow {
namespace expr {
/*!
* \brief crop expression, cut off the boundary region, reverse operation of padding
* \tparam SrcExp source expression to be pooled from
* \tparam DType the type of elements
* \tparam srcdim dimension of src
*/
template<typename SrcExp, typename DType, int srcdim>
struct CroppingExp:
public MakeTensorExp<CroppingExp<SrcExp, DType, srcdim>,
SrcExp, srcdim, DType> {
/*! \brief source operand */
const SrcExp &src_;
/*! \brief pad height */
index_t pad_height_;
  /*! \brief pad width */
index_t pad_width_;
/*! \brief src height */
index_t src_height_;
/*! \brief constructor */
explicit CroppingExp(const SrcExp &src, Shape<2> cshape)
: src_(src) {
this->shape_ = ShapeCheck<srcdim, SrcExp>::Check(src_);
CHECK_GE(this->shape_[srcdim - 2], cshape[0]) << "CroppingExp: height requirement not met";
CHECK_GE(this->shape_[srcdim - 1], cshape[1]) << "CroppingExp: width requirement not met";
pad_height_ = (this->shape_[srcdim - 2] - cshape[0]) / 2;
pad_width_ = (this->shape_[srcdim - 1] - cshape[1]) / 2;
src_height_ = this->shape_[srcdim - 2];
this->shape_[srcdim - 2] = cshape[0]; // height
this->shape_[srcdim - 1] = cshape[1]; // width
}
/*! \brief constructor */
explicit CroppingExp(const SrcExp &src, Shape<2> cshape,
index_t start_height, index_t start_width)
: src_(src), pad_height_(start_height), pad_width_(start_width) {
this->shape_ = ShapeCheck<srcdim, SrcExp>::Check(src_);
CHECK_GE(this->shape_[srcdim - 2], cshape[0] + start_height)
<< "CroppingExp: height requirement not met";
CHECK_GE(this->shape_[srcdim - 1], cshape[1] + start_width)
<< "CroppingExp: width requirement not met";
src_height_ = this->shape_[srcdim - 2];
this->shape_[srcdim - 2] = cshape[0]; // height
this->shape_[srcdim - 1] = cshape[1]; // width
}
}; // struct CroppingExp
/*!
 * \brief reverse operation of padding, cut off boundaries,
* crop output from center of input
* \param src original image batches
* \param oshape output shape to be cropped
 * \return expression corresponding to the cropped result
* \tparam SrcExp source expression
* \tparam DType the type of elements
* \tparam etype type of expression
*/
template<typename SrcExp, typename DType, int etype>
inline CroppingExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>
crop(const Exp<SrcExp, DType, etype> &src, Shape<2> oshape) {
TypeCheckPass<ExpInfo<SrcExp>::kDim >= 2>
::Error_Expression_Does_Not_Meet_Dimension_Req();
return CroppingExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>(src.self(), oshape);
}
/*!
* \brief same as crop, but can specify starting position to do cropping
* \param src original image batches
* \param oshape output shape to be cropped
* \param start_height start height position to do cropping
* \param start_width start width position to do cropping
 * \return expression corresponding to the cropped result
* \tparam SrcExp source expression
* \tparam DType the type of elements
* \tparam etype type of expression
*/
template<typename SrcExp, typename DType, int etype>
inline CroppingExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>
crop(const Exp<SrcExp, DType, etype> &src, Shape<2> oshape,
index_t start_height, index_t start_width) {
TypeCheckPass<ExpInfo<SrcExp>::kDim >= 2>
::Error_Expression_Does_Not_Meet_Dimension_Req();
return CroppingExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>
(src.self(), oshape, start_height, start_width);
}
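// Usage sketch (illustrative only; names are assumptions): recover the original
// (height, width) region from the padded batch above by cropping its center;
// Shape2 builds the 2D target shape.
//   restored = crop(padded, Shape2(height, width));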
//----------------------
// Execution plan
//----------------------
template<typename SrcExp, typename DType, int srcdim>
struct Plan<CroppingExp<SrcExp, DType, srcdim>, DType> {
public:
explicit Plan(const CroppingExp<SrcExp, DType, srcdim> &e)
: src_(MakePlan(e.src_)),
pad_height_(e.pad_height_), pad_width_(e.pad_width_),
new_height_(e.shape_[srcdim - 2]), src_height_(e.src_height_) {}
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
const index_t x = j;
const index_t y = i % new_height_;
const index_t c = i / new_height_;
const index_t h = y + pad_height_;
const index_t w = x + pad_width_;
return src_.Eval(c * src_height_ + h, w);
}
private:
Plan<SrcExp, DType> src_;
const index_t pad_height_, pad_width_;
const index_t new_height_;
const index_t src_height_;
};
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXTENSION_CROP_H_
//===== EXPANDED: ../mshadow/mshadow/extension/crop.h =====
//===== EXPANDING: ../mshadow/mshadow/extension/mirror.h =====
/*!
* Copyright (c) 2014 by Contributors
* \file mirror.h
* \brief support for mirror
* \author Tianqi Chen
*/
#ifndef MSHADOW_EXTENSION_MIRROR_H_
#define MSHADOW_EXTENSION_MIRROR_H_
namespace mshadow {
namespace expr {
/*!
 * \brief mirror expression, mirror an image in width
* \tparam SrcExp source expression to be mirrored
* \tparam DType the type of elements
* \tparam srcdim dimension of src
*/
template<typename SrcExp, typename DType, int srcdim>
struct MirroringExp:
public MakeTensorExp<MirroringExp<SrcExp, DType, srcdim>,
SrcExp, srcdim, DType> {
/*! \brief source operand */
const SrcExp &src_;
/*! \brief constructor */
explicit MirroringExp(const SrcExp &src) : src_(src) {
this->shape_ = ShapeCheck<srcdim, SrcExp>::Check(src_);
}
};
/*!
* \brief mirroring expression, mirror images in width
* \param src original image batches
* \return expression corresponding to mirrored result
* \tparam SrcExp source expression
* \tparam DType the type of elements
* \tparam etype type of expression
*/
template<typename SrcExp, typename DType, int etype>
inline MirroringExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>
mirror(const Exp<SrcExp, DType, etype> &src) {
TypeCheckPass<ExpInfo<SrcExp>::kDim >= 2>
::Error_Expression_Does_Not_Meet_Dimension_Req();
return MirroringExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>(src.self());
}
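// Usage sketch (illustrative only; `img` is an assumed image batch): horizontally
// flip every image, a common data-augmentation step.
//   flipped = mirror(img);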
//----------------------
// Execution plan
//----------------------
template<typename SrcExp, typename DType, int srcdim>
struct Plan<MirroringExp<SrcExp, DType, srcdim>, DType> {
public:
explicit Plan(const MirroringExp<SrcExp, DType, srcdim> &e)
: src_(MakePlan(e.src_)), width_(e.shape_[srcdim - 1]) {}
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
return src_.Eval(i, width_ - j - 1);
}
private:
Plan<SrcExp, DType> src_;
const index_t width_;
};
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXTENSION_MIRROR_H_
//===== EXPANDED: ../mshadow/mshadow/extension/mirror.h =====
//===== EXPANDING: ../mshadow/mshadow/extension/concat.h =====
/*!
* Copyright (c) 2014 by Contributors
* \file concat.h
* \brief support for concatenation
*/
#ifndef MSHADOW_EXTENSION_CONCAT_H_
#define MSHADOW_EXTENSION_CONCAT_H_
namespace mshadow {
namespace expr {
/*!
 * \brief concat expression, concatenate two tensors along one dimension
* \tparam LhsExp left expression
* \tparam RhsExp right expression
* \tparam DType the type of elements
* \tparam srcdim dimension of src
* \tparam dimsrc_m_cat dimsrc - dimcat
*/
template<typename LhsExp, typename RhsExp,
typename Device, typename DType,
int srcdim, int dimsrc_m_cat>
struct ConcatExp : public TRValue<ConcatExp<LhsExp, RhsExp,
Device, DType,
srcdim, dimsrc_m_cat>,
Device, srcdim, DType> {
static const int dimcat = srcdim - dimsrc_m_cat;
const LhsExp &src1_;
const RhsExp &src2_;
index_t dcat_src1_;
index_t dcat_src2_;
  Shape<srcdim> shape_;
ConcatExp(const LhsExp &src1, const RhsExp &src2) : src1_(src1), src2_(src2) {
Shape<srcdim> sshape1 = ShapeCheck<srcdim, LhsExp>::Check(src1_);
Shape<srcdim> sshape2 = ShapeCheck<srcdim, RhsExp>::Check(src2_);
#pragma unroll
for (int i = 0; i < srcdim; ++i) {
if (i != dimcat) {
CHECK_EQ(sshape1[i], sshape2[i]) << "ConcatExp: shape mismatch";
}
}
this->shape_ = sshape1;
this->shape_[dimcat] = sshape1[dimcat] + sshape2[dimcat];
this->dcat_src1_ = sshape1[dimcat];
this->dcat_src2_ = sshape2[dimcat];
}
template<typename E, int etype>
inline void
operator=(const expr::Exp<E, DType, etype> &exp) {
this->__assign(exp);
}
inline void
operator=(const DType &exp) {
this->__assign(exp);
}
}; // struct ConcatExp
/*!
 * \brief concat two 4D tensors
 * \param src1 source tensor1
 * \param src2 source tensor2
 * \return concatenated 4D tensor
 * \tparam cdim the dimension to concatenate on
* \tparam SrcExp source expression
* \tparam DType the type of elements
* \tparam etype type of expression
*/
template<int cdim, typename LhsExp, typename RhsExp,
typename Device, typename DType, int srcdim>
inline ConcatExp<LhsExp, RhsExp, Device, DType, srcdim, srcdim - cdim>
concat(const TRValue<LhsExp, Device, srcdim, DType> &src1,
const TRValue<RhsExp, Device, srcdim, DType> &src2) {
TypeCheckPass<ExpInfo<LhsExp>::kDim == ExpInfo<RhsExp>::kDim>
::Error_Expression_Does_Not_Meet_Dimension_Req();
TypeCheckPass<cdim < srcdim && ExpInfo<LhsExp>::kDim == srcdim>
::Error_Expression_Does_Not_Meet_Dimension_Req();
return ConcatExp<LhsExp, RhsExp, Device, DType, srcdim, srcdim - cdim>
(src1.self(), src2.self());
}
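// Usage sketch (illustrative only; names are assumptions): concatenate two 4D
// feature maps along the channel dimension (dimension 1 of
// (batch, channel, height, width)); all other dimensions must match.
//   merged = concat<1>(fmap_a, fmap_b);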
//------------------------
// engine plugin
//------------------------
// runtime shapecheck
template<typename LhsExp, typename RhsExp,
typename Device, typename DType,
int srcdim, int dimsrc_m_cat>
struct ShapeCheck<srcdim, ConcatExp<LhsExp, RhsExp, Device, DType, srcdim, dimsrc_m_cat> >{
inline static Shape<srcdim> Check(const ConcatExp<LhsExp, RhsExp,
Device, DType, srcdim, dimsrc_m_cat> &t) {
return t.shape_;
}
};
template<typename LhsExp, typename RhsExp,
typename Device, typename DType,
int srcdim, int dimsrc_m_cat>
struct StreamInfo<Device, ConcatExp<LhsExp, RhsExp, Device, DType, srcdim, dimsrc_m_cat> >{
inline static Stream<Device> *
Get(const ConcatExp<LhsExp, RhsExp, Device, DType, srcdim, dimsrc_m_cat> &t) {
Stream<Device> *lhs = StreamInfo<Device, LhsExp>::Get(t.src1_);
Stream<Device> *rhs = StreamInfo<Device, RhsExp>::Get(t.src2_);
if (lhs != rhs) return NULL;
return lhs;
}
};
// static typecheck
template<typename LhsExp, typename RhsExp,
typename Device, typename DType,
int srcdim, int dimsrc_m_cat>
struct ExpInfo<ConcatExp<LhsExp, RhsExp, Device, DType, srcdim, dimsrc_m_cat> >{
static const int kDimLhs = ExpInfo<LhsExp>::kDim;
static const int kDimRhs = ExpInfo<RhsExp>::kDim;
// copy from binarymap
static const int kDim = (kDimLhs >= 0 && kDimRhs >= 0) ?\
(kDimLhs == 0 ?\
kDimRhs :\
((kDimRhs == 0 || kDimLhs == kDimRhs) ? kDimLhs : -1)) : -1;
static const int kDevMask = ExpInfo<LhsExp>::kDevMask & ExpInfo<RhsExp>::kDevMask;
};
//----------------------
// Execution plan
//---------------------
template<typename LhsExp, typename RhsExp,
typename Device, typename DType,
int srcdim, int dimsrc_m_cat>
struct Plan<ConcatExp<LhsExp, RhsExp, Device, DType, srcdim, dimsrc_m_cat>, DType> {
public:
static const int dimcat = srcdim - dimsrc_m_cat;
explicit Plan(const ConcatExp<LhsExp, RhsExp, Device, DType, srcdim, dimsrc_m_cat> &e)
: src1_(MakePlan(e.src1_)), src2_(MakePlan(e.src2_)),
height_(e.shape_.ProdShape(dimcat + 1, srcdim - 1)),
ch_src1_(e.dcat_src1_), ch_src2_(e.dcat_src2_), ch_(e.shape_[dimcat]) {}
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
const index_t y = i % height_;
i /= height_;
const index_t c = i % ch_;
const index_t b = i / ch_;
const index_t x = j;
if (c < ch_src1_) {
return src1_.Eval((b * ch_src1_ + c) * height_ + y, x);
} else {
return src2_.Eval((b * ch_src2_ + c - ch_src1_) * height_ + y, x);
}
}
MSHADOW_XINLINE DType &REval(index_t i, index_t j) {
const index_t y = i % height_;
i /= height_;
const index_t c = i % ch_;
const index_t b = i / ch_;
const index_t x = j;
if (c < ch_src1_) {
return src1_.REval((b * ch_src1_ + c) * height_ + y, x);
} else {
return src2_.REval((b * ch_src2_ + c - ch_src1_) * height_ + y, x);
}
}
private:
Plan<LhsExp, DType> src1_;
Plan<RhsExp, DType> src2_;
const index_t height_, ch_src1_, ch_src2_, ch_;
}; // struct Plan
// specialize for concat in x
template<typename LhsExp, typename RhsExp,
typename Device, typename DType,
int srcdim>
struct Plan<ConcatExp<LhsExp, RhsExp, Device, DType, srcdim, 1>, DType> {
public:
explicit Plan(const ConcatExp<LhsExp, RhsExp, Device, DType, srcdim, 1> &e)
: src1_(MakePlan(e.src1_)), src2_(MakePlan(e.src2_)),
width_src1_(e.dcat_src1_) {}
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
if (x < width_src1_) {
return src1_.Eval(y, x);
} else {
return src2_.Eval(y, x - width_src1_);
}
}
MSHADOW_XINLINE DType &REval(index_t y, index_t x) {
if (x < width_src1_) {
return src1_.REval(y, x);
} else {
return src2_.REval(y, x - width_src1_);
}
}
private:
Plan<LhsExp, DType> src1_;
Plan<RhsExp, DType> src2_;
const index_t width_src1_;
};
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXTENSION_CONCAT_H_
//===== EXPANDED: ../mshadow/mshadow/extension/concat.h =====
//===== EXPANDING: ../mshadow/mshadow/extension/choose.h =====
/*!
* Copyright (c) 2014 by Contributors
* \file choose.h
* \brief support for implicit array selection operation
* \author Tianqi Chen
*/
#ifndef MSHADOW_EXTENSION_CHOOSE_H_
#define MSHADOW_EXTENSION_CHOOSE_H_
namespace mshadow {
namespace expr {
/*!
* \brief Make a choice of index in the lowest changing dimension.
* \tparam SrcExp type of lhs expression
* \tparam IndexExp type of index expression
* \tparam DType the type of elements
*/
template<typename SrcExp, typename IndexExp, typename DType>
struct MatChooseRowElementExp:
public Exp<MatChooseRowElementExp<SrcExp, IndexExp, DType>,
DType, type::kChainer> {
/*! \brief source operand */
const SrcExp &src_;
/*! \brief index operand */
const IndexExp &index_;
/*! \brief constructor */
MatChooseRowElementExp(const SrcExp &src, const IndexExp &index)
: src_(src), index_(index) {}
};
template<typename SrcExp, typename IndexExp,
typename DType, typename IDType, int e1, int e2>
inline MatChooseRowElementExp<SrcExp, IndexExp, DType>
mat_choose_row_element(const Exp<SrcExp, DType, e1> &src,
const Exp<IndexExp, IDType, e2> &index) {
TypeCheckPass<ExpInfo<SrcExp>::kDim == 2 && ExpInfo<IndexExp>::kDim == 1>
::Error_Expression_Does_Not_Meet_Dimension_Req();
return MatChooseRowElementExp<SrcExp, IndexExp, DType>(src.self(), index.self());
}
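// Usage sketch (illustrative only; names are assumptions): from a
// (batch, num_class) probability matrix `prob`, pick the entry indexed by the
// 1D label vector `label` in each row, giving a length-batch vector.
//   picked = mat_choose_row_element(prob, label);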
//----------------------
// Execution plan
//----------------------
template<typename SrcExp, typename IndexExp, typename DType>
struct Plan<MatChooseRowElementExp<SrcExp, IndexExp, DType>, DType> {
public:
explicit Plan(const MatChooseRowElementExp<SrcExp, IndexExp, DType> &e)
: src_(MakePlan(e.src_)),
index_(MakePlan(e.index_)) {
}
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
index_t idx = static_cast<index_t>(index_.Eval(0, x));
return src_.Eval(x, idx);
}
private:
expr::Plan<SrcExp, DType> src_;
expr::Plan<IndexExp, DType> index_;
};
template<typename SrcExp, typename IndexExp, typename DType>
inline Plan<MatChooseRowElementExp<SrcExp, IndexExp, DType>, DType>
MakePlan(const MatChooseRowElementExp<SrcExp, IndexExp, DType> &exp) {
return Plan<MatChooseRowElementExp<SrcExp, IndexExp, DType>, DType>(exp);
}
template<int dim, typename SrcExp, typename IndexExp, typename DType>
struct ShapeCheck<dim, MatChooseRowElementExp<SrcExp, IndexExp, DType> > {
inline static Shape<dim>
Check(const MatChooseRowElementExp<SrcExp, IndexExp, DType> &t) {
CHECK(dim == 1)
<< "MatChooseRowElementExp only support 1 dimension output";
Shape<2> shape1 = ShapeCheck<2, SrcExp>::Check(t.src_);
Shape<dim> shape2 = ShapeCheck<dim, IndexExp>::Check(t.index_);
CHECK_EQ(shape1[0], shape2[0])
<< "mat_choose_row_element index length and number of rows in matrix";
return shape2;
}
};
template<typename SrcExp, typename IndexExp, typename DType>
struct ExpInfo<MatChooseRowElementExp<SrcExp, IndexExp, DType> > {
static const int kDim = 1;
static const int kDevMask = ExpInfo<SrcExp>::kDevMask & ExpInfo<IndexExp>::kDevMask;
};
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXTENSION_CHOOSE_H_
//===== EXPANDED: ../mshadow/mshadow/extension/choose.h =====
//===== EXPANDING: ../mshadow/mshadow/extension/fill.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file fill.h
* \brief support for implicit array filling operation
* \author Xingjian Shi
*/
#ifndef MSHADOW_EXTENSION_FILL_H_
#define MSHADOW_EXTENSION_FILL_H_
namespace mshadow {
namespace expr {
/*!
* \brief Set value of a specific element in each line of the data matrix.
* \tparam SrcExp type of src expression
* \tparam ValExp type of val expression
* \tparam IndexExp type of index expression
* \tparam DType the type of ret expression
*/
template<typename SrcExp, typename ValExp, typename IndexExp, typename DType>
struct MatFillRowElementExp:
public Exp<MatFillRowElementExp<SrcExp, ValExp, IndexExp, DType>,
DType, type::kChainer> {
/*! \brief src operand */
const SrcExp &src_;
const ValExp &val_;
/*! \brief index operand */
const IndexExp &index_;
/*! \brief constructor */
MatFillRowElementExp(const SrcExp &src, const ValExp &val, const IndexExp &index)
: src_(src), val_(val), index_(index) {}
};
template<typename SrcExp, typename ValExp, typename IndexExp,
typename SDType, typename VDType, typename IDType, int e1, int e2, int e3>
inline MatFillRowElementExp<SrcExp, ValExp, IndexExp, SDType>
mat_fill_row_element(const Exp<SrcExp, SDType, e1> &src,
const Exp<ValExp, VDType, e2> &val,
const Exp<IndexExp, IDType, e3> &index) {
TypeCheckPass<ExpInfo<SrcExp>::kDim == 2 && ExpInfo<ValExp>::kDim == 1
&& ExpInfo<IndexExp>::kDim == 1>::Error_Expression_Does_Not_Meet_Dimension_Req();
return MatFillRowElementExp<SrcExp, ValExp, IndexExp, SDType>(src.self(),
val.self(), index.self());
}
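// Usage sketch (illustrative only; names are assumptions): in each row y of
// `src`, replace the element at column index[y] with val[y] and keep the rest.
//   out = mat_fill_row_element(src, val, index);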
//----------------------
// Execution plan
//----------------------
template<typename SrcExp, typename ValExp, typename IndexExp, typename DType>
struct Plan<MatFillRowElementExp<SrcExp, ValExp, IndexExp, DType>, DType> {
public:
explicit Plan(const MatFillRowElementExp<SrcExp, ValExp, IndexExp, DType> &e)
: src_(MakePlan(e.src_)),
val_(MakePlan(e.val_)),
index_(MakePlan(e.index_)) {
}
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
index_t idx = static_cast<index_t>(index_.Eval(0, y));
if (idx == x) {
return static_cast<DType>(val_.Eval(0, y));
} else {
return static_cast<DType>(src_.Eval(y, x));
}
}
private:
expr::Plan<SrcExp, DType> src_;
expr::Plan<ValExp, DType> val_;
expr::Plan<IndexExp, DType> index_;
};
template<typename SrcExp, typename ValExp, typename IndexExp, typename DType>
inline Plan<MatFillRowElementExp<SrcExp, ValExp, IndexExp, DType>, DType>
MakePlan(const MatFillRowElementExp<SrcExp, ValExp, IndexExp, DType> &exp) {
return Plan<MatFillRowElementExp<SrcExp, ValExp, IndexExp, DType>, DType>(exp);
}
template<int dim, typename SrcExp, typename ValExp, typename IndexExp, typename DType>
struct ShapeCheck<dim, MatFillRowElementExp<SrcExp, ValExp, IndexExp, DType> > {
inline static Shape<dim>
Check(const MatFillRowElementExp<SrcExp, ValExp, IndexExp, DType> &t) {
CHECK(dim == 2)
<< "MatFillRowElementExp only support 2 dimension output";
Shape<2> shape_src = ShapeCheck<2, SrcExp>::Check(t.src_);
Shape<1> shape_val = ShapeCheck<1, ValExp>::Check(t.val_);
Shape<1> shape_index = ShapeCheck<1, IndexExp>::Check(t.index_);
CHECK((shape_src[0] == shape_index[0]) && (shape_index[0] == shape_val[0]))
<< "mat_fill_row_element index length, val length and number of rows in matrix";
return shape_src;
}
};
template<typename SrcExp, typename ValExp, typename IndexExp, typename DType>
struct ExpInfo<MatFillRowElementExp<SrcExp, ValExp, IndexExp, DType> > {
static const int kDim = 2;
static const int kDevMask =
ExpInfo<SrcExp>::kDevMask & ExpInfo<ValExp>::kDevMask & ExpInfo<IndexExp>::kDevMask;
};
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXTENSION_FILL_H_
//===== EXPANDED: ../mshadow/mshadow/extension/fill.h =====
//===== EXPANDING: ../mshadow/mshadow/extension/one_hot.h =====
/*!
* Copyright (c) 2014 by Contributors
* \file one_hot.h
* \brief Create one-hot indicator array based on the index.
* \author Tianqi Chen
*/
#ifndef MSHADOW_EXTENSION_ONE_HOT_H_
#define MSHADOW_EXTENSION_ONE_HOT_H_
namespace mshadow {
namespace expr {
/*!
* \brief Create a one-hot indicator array.
* \tparam IndexExp type of index expression
* \tparam DType the type of elements
*/
template<typename IndexExp, typename DType>
struct OneHotEncodeExp:
public Exp<OneHotEncodeExp<IndexExp, DType>,
DType, type::kChainer> {
/*! \brief index operand */
const IndexExp &index_;
/*! \brief number of choices we can have. */
index_t num_choices_;
/*! \brief constructor */
OneHotEncodeExp(const IndexExp &index, index_t num_choices)
: index_(index), num_choices_(num_choices) {}
};
template<typename IndexExp,
typename IDType, int e1>
inline OneHotEncodeExp<IndexExp, default_real_t>
one_hot_encode(const Exp<IndexExp, IDType, e1> &index, index_t num_choices) {
TypeCheckPass<ExpInfo<IndexExp>::kDim == 1>
::Error_Expression_Does_Not_Meet_Dimension_Req();
return OneHotEncodeExp<IndexExp, default_real_t>(index.self(), num_choices);
}
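// Usage sketch (illustrative only; names are assumptions): expand a length-n
// label vector into an (n, num_class) one-hot matrix.
//   onehot = one_hot_encode(label, num_class);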
//----------------------
// Execution plan
//----------------------
template<typename IndexExp, typename DType>
struct Plan<OneHotEncodeExp<IndexExp, DType>, DType> {
public:
explicit Plan(const OneHotEncodeExp<IndexExp, DType> &e)
: index_(MakePlan(e.index_)) {
}
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
index_t idx = static_cast<index_t>(index_.Eval(0, y));
return static_cast<DType>(x == idx);
}
private:
expr::Plan<IndexExp, DType> index_;
};
template<typename IndexExp, typename DType>
inline Plan<OneHotEncodeExp<IndexExp, DType>, DType>
MakePlan(const OneHotEncodeExp<IndexExp, DType> &exp) {
return Plan<OneHotEncodeExp<IndexExp, DType>, DType>(exp);
}
template<int dim, typename IndexExp, typename DType>
struct ShapeCheck<dim, OneHotEncodeExp<IndexExp, DType> > {
inline static Shape<dim>
Check(const OneHotEncodeExp<IndexExp, DType> &t) {
CHECK(dim == 2)
<< "OneHotEncodeExp only support 2 dimension output";
Shape<1> shape = ShapeCheck<1, IndexExp>::Check(t.index_);
Shape<dim> ret;
ret[0] = shape[0];
ret[1] = t.num_choices_;
return ret;
}
};
template<typename IndexExp, typename DType>
struct ExpInfo<OneHotEncodeExp<IndexExp, DType> > {
static const int kDim = 2;
static const int kDevMask = ExpInfo<IndexExp>::kDevMask;
};
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXTENSION_ONE_HOT_H_
//===== EXPANDED: ../mshadow/mshadow/extension/one_hot.h =====
//===== EXPANDING: ../mshadow/mshadow/extension/slice.h =====
/*!
* Copyright (c) 2014 by Contributors
* \file slice.h
 * \brief support for slicing a certain dimension.
*/
#ifndef MSHADOW_EXTENSION_SLICE_H_
#define MSHADOW_EXTENSION_SLICE_H_
namespace mshadow {
namespace expr {
/*!
 * \brief slice expression, slice a tensor along a given dimension
 * \tparam SrcExp source expression
 * \tparam DType the type of elements
 * \tparam srcdim dimension of src
 * \tparam dimsrc_m_slice dimsrc - dimslice
*/
template<typename SrcExp,
typename Device, typename DType,
int srcdim, int dimsrc_m_slice>
struct SliceExp : public TRValue<SliceExp<SrcExp,
Device, DType,
srcdim, dimsrc_m_slice>,
Device, srcdim, DType> {
static const int dimslice = srcdim - dimsrc_m_slice;
const SrcExp &src_;
index_t ch_begin_;
index_t ch_old_;
Shape<srcdim> shape_;
SliceExp(const SrcExp &src, index_t begin, index_t end)
: src_(src), ch_begin_(begin) {
shape_ = ShapeCheck<srcdim, SrcExp>::Check(src_);
ch_old_ = shape_[dimslice];
CHECK(begin < shape_[dimslice] && end <= shape_[dimslice])
<< "The slice went out of range";
shape_[dimslice] = end - begin;
}
template<typename E, int etype>
inline void
operator=(const expr::Exp<E, DType, etype> &exp) {
this->__assign(exp);
}
inline void
operator=(const DType &exp) {
this->__assign(exp);
}
}; // struct Slice
/*!
* \brief Slice a Tensor
* \param src source tensor
* \param begin The beginning slice.
* \param end The end slice.
* \return sliced tensor
* \tparam sdim the dimension to slice on
* \tparam SrcExp source expression
* \tparam DType the type of elements
* \tparam etype type of expression
*/
template<int sdim, typename SrcExp,
typename Device, typename DType, int srcdim>
inline SliceExp<SrcExp, Device, DType, srcdim, srcdim - sdim>
slice(const TRValue<SrcExp, Device, srcdim, DType> &src, index_t begin, index_t end) {
TypeCheckPass<sdim < srcdim && ExpInfo<SrcExp>::kDim == srcdim>
::Error_Expression_Does_Not_Meet_Dimension_Req();
return SliceExp<SrcExp, Device, DType, srcdim, srcdim - sdim>(src.self(), begin, end);
}
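// Usage sketch (illustrative only; `fmap` is an assumed 4D feature map): keep
// channels [4, 8) by slicing along dimension 1 of (batch, channel, height, width).
//   sub = slice<1>(fmap, 4, 8);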
//------------------------
// engine plugin
//------------------------
// runtime shapecheck
template<typename SrcExp,
typename Device, typename DType,
int srcdim, int dimsrc_m_slice>
struct ShapeCheck<srcdim, SliceExp<SrcExp, Device, DType, srcdim, dimsrc_m_slice> >{
inline static Shape<srcdim> Check(const SliceExp<SrcExp,
Device, DType, srcdim, dimsrc_m_slice> &t) {
return t.shape_;
}
};
template<typename SrcExp,
typename Device, typename DType,
int srcdim, int dimsrc_m_slice>
struct StreamInfo<Device, SliceExp<SrcExp, Device, DType, srcdim, dimsrc_m_slice> >{
inline static Stream<Device> *
Get(const SliceExp<SrcExp, Device, DType, srcdim, dimsrc_m_slice> &t) {
return StreamInfo<Device, SrcExp>::Get(t.src_);
}
};
// static typecheck
template<typename SrcExp,
typename Device, typename DType,
int srcdim, int dimsrc_m_slice>
struct ExpInfo<SliceExp<SrcExp, Device, DType, srcdim, dimsrc_m_slice> >{
static const int kDim = ExpInfo<SrcExp>::kDim;
static const int kDevMask = ExpInfo<SrcExp>::kDevMask;
};
//----------------------
// Execution plan
//---------------------
template<typename SrcExp,
typename Device, typename DType,
int srcdim, int dimsrc_m_slice>
struct Plan<SliceExp<SrcExp, Device, DType, srcdim, dimsrc_m_slice>, DType> {
public:
static const int dimslice = srcdim - dimsrc_m_slice;
explicit Plan(const SliceExp<SrcExp, Device, DType, srcdim, dimsrc_m_slice> &e)
: src_(MakePlan(e.src_)),
height_(e.shape_.ProdShape(dimslice + 1, srcdim - 1)),
ch_begin_(e.ch_begin_), ch_old_(e.ch_old_), ch_(e.shape_[dimslice]) {}
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
const index_t y = i % height_;
i /= height_;
const index_t c = i % ch_ + ch_begin_;
const index_t b = i / ch_;
const index_t x = j;
return src_.Eval((b * ch_old_ + c) * height_ + y, x);
}
MSHADOW_XINLINE DType &REval(index_t i, index_t j) {
const index_t y = i % height_;
i /= height_;
const index_t c = i % ch_ + ch_begin_;
const index_t b = i / ch_;
const index_t x = j;
return src_.REval((b * ch_old_ + c) * height_ + y, x);
}
private:
Plan<SrcExp, DType> src_;
const index_t height_, ch_begin_, ch_old_, ch_;
}; // struct Plan
template<typename SrcExp,
typename Device, typename DType,
int srcdim>
struct Plan<SliceExp<SrcExp, Device, DType, srcdim, 1>, DType> {
public:
explicit Plan(const SliceExp<SrcExp, Device, DType, srcdim, 1> &e)
: src_(MakePlan(e.src_)),
ch_begin_(e.ch_begin_) {}
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
return src_.Eval(y, x + ch_begin_);
}
MSHADOW_XINLINE DType &REval(index_t y, index_t x) {
return src_.REval(y, x + ch_begin_);
}
private:
Plan<SrcExp, DType> src_;
const index_t ch_begin_;
};
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXTENSION_SLICE_H_
//===== EXPANDED: ../mshadow/mshadow/extension/slice.h =====
//===== EXPANDING: ../mshadow/mshadow/extension/slice_ex.h =====
/*!
* Copyright (c) 2014 by Contributors
 * \file slice_ex.h
 * \brief support for slicing with explicit per-dimension begin/end coordinates.
*/
#ifndef MSHADOW_EXTENSION_SLICE_EX_H_
#define MSHADOW_EXTENSION_SLICE_EX_H_
namespace mshadow {
namespace expr {
/*!
 * \brief slice expression, slice a tensor with explicit begin/end coordinates in every dimension
 * \tparam SrcExp source expression
 * \tparam DType the type of elements
 * \tparam srcdim dimension of src
*/
template<typename SrcExp, typename Device,
typename DType, int srcdim>
struct SliceExExp : public TRValue<SliceExExp<SrcExp,
Device, DType,
srcdim>,
Device, srcdim, DType> {
const SrcExp &src_;
Shape<srcdim> src_shape_;
Shape<srcdim> shape_;
const Shape<srcdim> begin_;
const Shape<srcdim> end_;
SliceExExp(const SrcExp &src, Shape<srcdim> begin, Shape<srcdim> end)
: src_(src), begin_(begin), end_(end) {
src_shape_ = ShapeCheck<srcdim, SrcExp>::Check(src_);
for (int i = 0; i < srcdim; ++i) {
shape_[i] = end_[i] - begin_[i];
}
}
template<typename E, int etype>
inline void
operator=(const expr::Exp<E, DType, etype> &exp) {
this->__assign(exp);
}
inline void
operator=(const DType &exp) {
this->__assign(exp);
}
}; // struct SliceEx
/*!
* \brief SliceEx a Tensor
* \param src source tensor
 * \param begin the per-dimension begin coordinates
 * \param end the per-dimension end coordinates (exclusive)
 * \return sliced tensor
* \tparam SrcExp source expression
* \tparam DType the type of elements
* \tparam etype type of expression
*/
template<typename SrcExp, typename Device,
typename DType, int srcdim>
inline SliceExExp<SrcExp, Device, DType, srcdim>
slice(const TRValue<SrcExp, Device, srcdim, DType> &src, Shape<srcdim> begin, Shape<srcdim> end) {
TypeCheckPass<ExpInfo<SrcExp>::kDim == srcdim>
::Error_Expression_Does_Not_Meet_Dimension_Req();
return SliceExExp<SrcExp, Device, DType, srcdim>(src.self(), begin, end);
}
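// Usage sketch (illustrative only; names are assumptions): cut an arbitrary
// hyper-rectangle out of a 4D tensor by giving per-dimension begin and
// (exclusive) end coordinates.
//   sub = slice(data, Shape4(0, 0, 2, 2), Shape4(batch, channel, 6, 6));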
//------------------------
// engine plugin
//------------------------
// runtime shapecheck
template<typename SrcExp, typename Device,
typename DType, int srcdim>
struct ShapeCheck<srcdim, SliceExExp<SrcExp, Device, DType, srcdim> >{
inline static Shape<srcdim> Check(const SliceExExp<SrcExp,
Device, DType, srcdim> &t) {
return t.shape_;
}
};
template<typename SrcExp, typename Device,
typename DType, int srcdim>
struct StreamInfo<Device, SliceExExp<SrcExp, Device, DType, srcdim> >{
inline static Stream<Device> *
Get(const SliceExExp<SrcExp, Device, DType, srcdim> &t) {
return StreamInfo<Device, SrcExp>::Get(t.src_);
}
};
// static typecheck
template<typename SrcExp, typename Device,
typename DType, int srcdim>
struct ExpInfo<SliceExExp<SrcExp, Device, DType, srcdim> >{
static const int kDim = ExpInfo<SrcExp>::kDim;
static const int kDevMask = ExpInfo<SrcExp>::kDevMask;
};
//----------------------
// Execution plan
//---------------------
template<typename SrcExp, typename Device,
typename DType, int srcdim>
struct Plan<SliceExExp<SrcExp, Device, DType, srcdim>, DType> {
public:
explicit Plan(const SliceExExp<SrcExp, Device, DType, srcdim> &e)
: src_(MakePlan(e.src_)), begin_(e.begin_),
src_shape_(e.src_shape_), shape_(e.shape_) {}
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
index_t idx = 0;
index_t stride = 1;
#pragma unroll
for (int k = srcdim-2; k >= 0; --k) {
idx += stride * (i%shape_[k] + begin_[k]);
i /= shape_[k];
stride *= src_shape_[k];
}
return src_.Eval(idx, j + begin_[srcdim-1]);
}
MSHADOW_XINLINE DType &REval(index_t i, index_t j) {
index_t idx = 0;
index_t stride = 1;
#pragma unroll
for (int k = srcdim-2; k >= 0; --k) {
idx += stride * (i%shape_[k] + begin_[k]);
i /= shape_[k];
stride *= src_shape_[k];
}
return src_.REval(idx, j + begin_[srcdim-1]);
}
private:
Plan<SrcExp, DType> src_;
const Shape<srcdim> begin_, src_shape_, shape_;
}; // struct Plan
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXTENSION_SLICE_EX_H_
//===== EXPANDED: ../mshadow/mshadow/extension/slice_ex.h =====
//===== EXPANDING: ../mshadow/mshadow/extension/take.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file take.h
 * \brief support for take (row gather by index)
* \author Bing Xu
*/
#ifndef MSHADOW_EXTENSION_TAKE_H_
#define MSHADOW_EXTENSION_TAKE_H_
namespace mshadow {
namespace expr {
/*! \brief Take rows from a matrix according to an index vector
* \tparam IndexExp type of index expression
* \tparam SrcExp type of src expression
* \tparam DType data type
*/
template<typename IndexExp, typename SrcExp, typename DType>
struct TakeExp: public Exp<TakeExp<IndexExp, SrcExp, DType>,
DType, type::kChainer> {
  /*! \brief index operand */
  const IndexExp &index_;
  /*! \brief embedding operand */
const SrcExp &src_;
/*! constructor */
TakeExp(const IndexExp &index, const SrcExp &src)
: index_(index), src_(src) {}
}; // struct TakeExp
template<typename IndexExp,
typename SrcExp,
typename DType,
int e1, int e2>
inline TakeExp<IndexExp, SrcExp, DType>
take(const Exp<IndexExp, DType, e1> &index,
const Exp<SrcExp, DType, e2> &src) {
return TakeExp<IndexExp, SrcExp, DType>(index.self(), src.self());
}
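// Usage sketch (illustrative only; names are assumptions): embedding lookup that
// gathers one row of the (vocab, dim) weight matrix for every entry of the 1D
// index vector `idx`, producing a (len(idx), dim) matrix.
//   embedded = take(idx, weight);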
//----------------------
// Execution plan
//----------------------
template<typename IndexExp, typename SrcExp, typename DType>
struct Plan<TakeExp<IndexExp, SrcExp, DType>, DType> {
public:
explicit Plan(const TakeExp<IndexExp, SrcExp, DType> &e)
: index_(MakePlan(e.index_)), src_(MakePlan(e.src_)) {
}
// TODO(xx): discuss W shape: in * out or out * in
// Now I use in * out
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
index_t idx = static_cast<index_t>(index_.Eval(0, y));
return static_cast<DType>(src_.Eval(idx, x));
}
private:
expr::Plan<IndexExp, DType> index_;
expr::Plan<SrcExp, DType> src_;
}; // struct Plan
template<typename IndexExp, typename SrcExp, typename DType>
inline Plan<TakeExp<IndexExp, SrcExp, DType>, DType>
MakePlan(const TakeExp<IndexExp, SrcExp, DType> &exp) {
return Plan<TakeExp<IndexExp, SrcExp, DType>, DType>(exp);
}
template<int dim, typename IndexExp, typename SrcExp, typename DType>
struct ShapeCheck<dim, TakeExp<IndexExp, SrcExp, DType> > {
inline static Shape<dim>
Check(const TakeExp<IndexExp, SrcExp, DType> &t) {
CHECK(dim == 2)
<< "TakeExp only support 2D output";
Shape<1> dshape = ShapeCheck<1, IndexExp>::Check(t.index_);
Shape<2> wshape = ShapeCheck<2, SrcExp>::Check(t.src_);
Shape<dim> ret;
ret[0] = dshape[0];
ret[1] = wshape[1];
return ret;
}
};
template<typename IndexExp, typename SrcExp, typename DType>
struct ExpInfo<TakeExp<IndexExp, SrcExp, DType> > {
static const int kDim = 2;
static const int kDevMask = ExpInfo<IndexExp>::kDevMask;
};
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXTENSION_TAKE_H_
//===== EXPANDED: ../mshadow/mshadow/extension/take.h =====
//===== EXPANDING: ../mshadow/mshadow/extension/take_grad.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file take_grad.h
 * \brief support for take_grad (gradient of the take/embedding lookup)
* \author Bing Xu
*/
#ifndef MSHADOW_EXTENSION_TAKE_GRAD_H_
#define MSHADOW_EXTENSION_TAKE_GRAD_H_
namespace mshadow {
namespace expr {
/*! \brief Calculate embedding gradient
* \tparam IndexExp type of index expression
* \tparam SrcExp type of src expression
* \tparam DType data type
*/
template<typename IndexExp, typename SrcExp, typename DType>
struct TakeGradExp : public Exp<TakeGradExp<IndexExp, SrcExp, DType>,
DType, type::kChainer> {
  /*! \brief index operand */
  const IndexExp &index_;
  /*! \brief output gradient operand */
const SrcExp &src_;
  /*! \brief input dimension (number of rows of the gradient result) */
const index_t input_dim_;
/*! \brief constructor */
TakeGradExp(const IndexExp &index, const SrcExp &src, const index_t input_dim)
: index_(index), src_(src), input_dim_(input_dim) {}
}; // struct TakeGradExp
template<typename IndexExp,
typename SrcExp,
typename DType,
int e1, int e2>
inline TakeGradExp<IndexExp, SrcExp, DType>
take_grad(const Exp<IndexExp, DType, e1> &index,
const Exp<SrcExp, DType, e2> &src,
const index_t input_dim) {
return TakeGradExp<IndexExp, SrcExp, DType>(index.self(),
src.self(),
input_dim);
}
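// Usage sketch (illustrative only; names are assumptions): scatter-add the
// gradient of the lookup above back into a (vocab, dim) weight gradient;
// `vocab` is the input dimension passed as input_dim.
//   grad_weight = take_grad(idx, grad_embedded, vocab);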
//----------------------
// Execution plan
//----------------------
template<typename IndexExp, typename SrcExp, typename DType>
struct Plan<TakeGradExp<IndexExp, SrcExp, DType>, DType> {
public:
explicit Plan(const TakeGradExp<IndexExp, SrcExp, DType> &e)
: index_(MakePlan(e.index_)),
src_(MakePlan(e.src_)),
batch_size_(ShapeCheck<1, IndexExp>::Check(e.index_)[0]) {
}
// now return shape: in * out
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
DType ret = 0.f;
for (index_t i = 0; i < batch_size_; ++i) {
index_t idx = static_cast<index_t>(index_.Eval(0, i));
if (idx == y) {
ret += static_cast<DType>(src_.Eval(i, x));
}
}
return ret;
}
private:
expr::Plan<IndexExp, DType> index_;
expr::Plan<SrcExp, DType> src_;
const index_t batch_size_;
}; // struct Plan
template<typename IndexExp, typename SrcExp, typename DType>
inline Plan<TakeGradExp<IndexExp, SrcExp, DType>, DType>
MakePlan(const TakeGradExp<IndexExp, SrcExp, DType> &exp) {
return Plan<TakeGradExp<IndexExp, SrcExp, DType>, DType>(exp);
}
template<int dim, typename IndexExp, typename SrcExp, typename DType>
struct ShapeCheck<dim, TakeGradExp<IndexExp, SrcExp, DType> > {
inline static Shape<dim>
Check(const TakeGradExp<IndexExp, SrcExp, DType> &t) {
CHECK(dim == 2)
<< "TakeGradExp only support 2D output";
// Shape<1> dshape = ShapeCheck<1, IndexExp>::Check(t.index_);
Shape<2> gshape = ShapeCheck<2, SrcExp>::Check(t.src_);
Shape<dim> ret;
ret[0] = t.input_dim_;
ret[1] = gshape[1];
return ret;
}
}; // struct ShapeCheck
template<typename IndexExp, typename SrcExp, typename DType>
struct ExpInfo<TakeGradExp<IndexExp, SrcExp, DType> > {
static const int kDim = 2;
static const int kDevMask = ExpInfo<IndexExp>::kDevMask;
};
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXTENSION_TAKE_GRAD_H_
//===== EXPANDED: ../mshadow/mshadow/extension/take_grad.h =====
//===== EXPANDING: ../mshadow/mshadow/extension/reduce_with_axis.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file reduce_with_axis.h
* \brief
* \author Junyuan Xie
*/
#ifndef MSHADOW_EXTENSION_REDUCE_WITH_AXIS_H_
#define MSHADOW_EXTENSION_REDUCE_WITH_AXIS_H_
namespace mshadow {
namespace expr {
/*! \brief reduce out the dimension of src labeled by axis.
* \tparam Reducer type of reducer
* \tparam SrcExp type of source expression
* \tparam DType data type
*/
template<typename Reducer, typename SrcExp, typename DType, int dimsrc, bool mask, int dimdst>
struct ReduceWithAxisExp:
public MakeTensorExp<ReduceWithAxisExp<Reducer, SrcExp, DType, dimsrc, mask, dimdst>,
SrcExp, dimdst, DType> {
  /*! \brief source operand */
const SrcExp &src_;
/*! \brief size of last destination dimension */
index_t last_dst_dim_;
/*! \brief size of trailing dimensions */
index_t trailing_;
/*! \brief size of axis dimension */
index_t size_;
/*! \brief size of last src dimension */
index_t last_;
/*! constructor */
explicit ReduceWithAxisExp(const SrcExp &src, int axis)
: src_(src) {
bool keepdim = (dimsrc == dimdst);
CHECK(dimsrc > axis) << "reduce axis out of bound";
Shape<dimsrc> src_shape = ShapeCheck<dimsrc, SrcExp>::Check(src_);
for (int i = 0; i < axis; ++i) {
this->shape_[i] = src_shape[i];
}
this->size_ = src_shape[axis];
this->trailing_ = 1;
if (!keepdim) {
for (int i = axis + 1; i < dimsrc; ++i) {
this->trailing_ *= src_shape[i];
this->shape_[i - 1] = src_shape[i];
}
} else {
this->shape_[axis] = 1;
for (index_t i = axis + 1; i < dimsrc; ++i) {
this->trailing_ *= src_shape[i];
this->shape_[i] = src_shape[i];
}
}
this->last_ = src_shape[dimsrc - 1];
this->last_dst_dim_ = this->shape_[dimdst - 1];
}
}; // struct ReduceWithAxisExp
/*!
 * \brief reduce out the dimension of src labeled by axis.
 * \tparam Reducer type of the reducing operation
 * \tparam mask whether to output the index of the reduced value (e.g. argmax) instead of the value itself
 * \tparam SrcExp source expression
 * \tparam DType data type
 * \tparam etype type of the expression
 */
template<typename Reducer, bool mask, typename SrcExp, typename DType, int etype>
inline ReduceWithAxisExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim, mask,
ExpInfo<SrcExp>::kDim - 1>
reduce_with_axis(const Exp<SrcExp, DType, etype> &src, int axis) {
return ReduceWithAxisExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim, mask,
ExpInfo<SrcExp>::kDim- 1>(src.self(), axis);
}
/*!
 * \brief reduce out the dimension of src labeled by axis, with keepdim turned on.
 * \tparam Reducer type of the reducing operation
 * \tparam mask whether to output the index of the reduced value (e.g. argmax) instead of the value itself
 * \tparam SrcExp source expression
 * \tparam DType data type
 * \tparam etype type of the expression
 */
template<typename Reducer, bool mask, typename SrcExp, typename DType, int etype>
inline ReduceWithAxisExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim, mask,
ExpInfo<SrcExp>::kDim>
reduce_keepdim(const Exp<SrcExp, DType, etype> &src, int axis) {
return ReduceWithAxisExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim, mask,
ExpInfo<SrcExp>::kDim>(src.self(), axis);
}
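// Usage sketch (illustrative only; the tensors and shapes below are assumptions):
//   Tensor<cpu, 3, float> data;  // shape (2, 5, 4)
//   Tensor<cpu, 2, float> out;   // shape (2, 4)
//   Tensor<cpu, 3, float> kept;  // shape (2, 1, 4)
//   out  = reduce_with_axis<red::maximum, false>(data, 1);  // max over axis 1, axis removed
//   kept = reduce_keepdim<red::sum, false>(data, 1);        // sum over axis 1, axis kept with size 1
//   // with mask=true the index of the reduced value (e.g. the argmax position) is returned instead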
//----------------------
// Execution plan
//----------------------
template<typename Reducer, typename SrcExp, typename DType, int dimsrc, bool mask, int dimdst>
struct Plan<ReduceWithAxisExp<Reducer, SrcExp, DType, dimsrc, mask, dimdst>, DType> {
public:
explicit Plan(const ReduceWithAxisExp<Reducer, SrcExp, DType, dimsrc, mask, dimdst> &e)
: src_(MakePlan(e.src_)), last_dst_dim_(e.last_dst_dim_), trailing_(e.trailing_),
size_(e.size_), last_(e.last_) {}
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
index_t x = (i*last_dst_dim_ + j)/trailing_;
index_t y = (i*last_dst_dim_ + j)%trailing_;
if (mask) {
index_t idx = 0;
DType res; Reducer::SetInitValue(res);
for (index_t k = 0; k < size_; ++k) {
index_t z = (x*size_+k)*trailing_+y;
DType tmp = res;
Reducer::Reduce(res, src_.Eval(z/last_, z%last_));
if (tmp != res) {
idx = k;
}
}
return static_cast<DType>(static_cast<int>(idx));
} else {
DType res; Reducer::SetInitValue(res);
for (index_t k = 0; k < size_; ++k) {
index_t z = (x*size_+k)*trailing_+y;
Reducer::Reduce(res, src_.Eval(z/last_, z%last_));
}
return res;
}
}
private:
Plan<SrcExp, DType> src_;
const index_t last_dst_dim_, trailing_, size_, last_;
};
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXTENSION_REDUCE_WITH_AXIS_H_
//===== EXPANDED: ../mshadow/mshadow/extension/reduce_with_axis.h =====
//===== EXPANDING: ../mshadow/mshadow/extension/broadcast_with_axis.h =====
/*!
* Copyright (c) 2016 by Contributors
* \file broadcast_with_axis.h
* \brief
* \author Junyuan Xie, Xingjian Shi
*/
#ifndef MSHADOW_EXTENSION_BROADCAST_WITH_AXIS_H_
#define MSHADOW_EXTENSION_BROADCAST_WITH_AXIS_H_
namespace mshadow {
namespace expr {
/*!
 * \brief Broadcast the tensor along the given axis. If keepdim is off, a new broadcast dimension is inserted after axis; otherwise the axis itself (whose size must be 1) is broadcast.
* \tparam SrcExp source expression
* \tparam DType data type
* \tparam dimsrc source dimension
* \tparam dimdst destination dimension
*/
template<typename SrcExp, typename DType, int dimsrc, int dimdst>
struct BroadcastWithAxisExp:
public MakeTensorExp<BroadcastWithAxisExp<SrcExp, DType, dimsrc, dimdst>,
SrcExp, dimdst, DType> {
  /*! \brief data operand */
const SrcExp &src_;
/*! \brief size of the last dimension of dst */
index_t dst_last_;
/*! \brief product of the dimensions after the broadcasting axis */
index_t trailing_;
/*! \brief new dimension of the broadcasting axis*/
index_t size_;
/*! \brief size of the last dimension of src*/
index_t last_;
/*! constructor */
BroadcastWithAxisExp(const SrcExp &src, const int axis, const index_t size)
: src_(src), size_(size) {
bool keepdim = (dimsrc == dimdst);
Shape<dimsrc> src_shape = ShapeCheck<dimsrc, SrcExp>::Check(src_);
this->trailing_ = 1;
if (!keepdim) {
      CHECK(dimsrc > axis && axis >= -1) << "broadcast axis (no keepdim) out of bound, " <<
        "axis must be between -1 and " << dimsrc - 1 << ", given=" << axis << ".";
for (int i = 0; i <= axis; ++i) {
this->shape_[i] = src_shape[i];
}
this->shape_[axis + 1] = size_;
for (int i = axis + 1; i < dimsrc; ++i) {
this->trailing_ *= src_shape[i];
this->shape_[i + 1] = src_shape[i];
}
} else {
      CHECK(dimdst > axis && axis >= 0) << "broadcast axis (keepdim) out of bound, " <<
        "axis must be between 0 and " << dimdst - 1 << ", given=" << axis << ".";
CHECK_EQ(src_shape[axis], 1) << "Size of the dimension of the broadcasting axis must be 1" <<
" when keepdim is on, src_shape[" << axis << "]=" << src_shape[axis] << ".";
for (int i = 0; i <= axis - 1; ++i) {
this->shape_[i] = src_shape[i];
}
this->shape_[axis] = size_;
for (int i = axis + 1; i < dimdst; ++i) {
this->trailing_ *= src_shape[i];
this->shape_[i] = src_shape[i];
}
}
this->last_ = src_shape[dimsrc - 1];
this->dst_last_ = this->shape_[dimdst - 1];
}
}; // struct BroadcastWithAxisExp
/*!
 * \brief Broadcast the tensor by inserting a new dimension after the given axis.
* \tparam SrcExp source expression
* \tparam DType data type
* \tparam etype type of the expression
*/
template<typename SrcExp, typename DType, int etype>
inline BroadcastWithAxisExp<SrcExp, DType, ExpInfo<SrcExp>::kDim,
ExpInfo<SrcExp>::kDim + 1>
broadcast_with_axis(const Exp<SrcExp, DType, etype> &src, const int axis, const index_t size) {
return BroadcastWithAxisExp<SrcExp, DType, ExpInfo<SrcExp>::kDim,
ExpInfo<SrcExp>::kDim + 1>(src.self(), axis, size);
}
/*!
* \brief Broadcasting the tensor in the given axis (keepdim turned on)
* \tparam SrcExp source expression
* \tparam DType data type
* \tparam etype type of the expression
*/
template<typename SrcExp, typename DType, int etype>
inline BroadcastWithAxisExp<SrcExp, DType, ExpInfo<SrcExp>::kDim,
ExpInfo<SrcExp>::kDim>
broadcast_keepdim(const Exp<SrcExp, DType, etype> &src, const int axis, const index_t size) {
return BroadcastWithAxisExp<SrcExp, DType, ExpInfo<SrcExp>::kDim,
ExpInfo<SrcExp>::kDim>(src.self(), axis, size);
}
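// Usage sketch (illustrative only; the tensors and shapes below are assumptions):
//   Tensor<cpu, 2, float> src;   // shape (2, 4)
//   Tensor<cpu, 3, float> dst;   // shape (2, 3, 4)
//   dst = broadcast_with_axis(src, 0, 3);   // insert a new broadcast dim of size 3 after axis 0
//   Tensor<cpu, 3, float> kd;    // shape (2, 1, 4), axis 1 has size 1
//   Tensor<cpu, 3, float> out;   // shape (2, 3, 4)
//   out = broadcast_keepdim(kd, 1, 3);      // broadcast the existing axis 1 from size 1 to size 3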
/*!
 * \brief Broadcast the tensor along multiple axes. The size of the source tensor
    along each of the given axes must be 1.
* \tparam SrcExp source expression
* \tparam DType data type
* \tparam dimsrc source dimension
* \tparam axesnum number of broadcasting dimensions
*/
template<typename SrcExp, typename DType, int dimsrc>
struct BroadcastWithMultiAxesExp :
public MakeTensorExp<BroadcastWithMultiAxesExp<SrcExp, DType, dimsrc>,
SrcExp, dimsrc, DType> {
  /*! \brief data operand */
const SrcExp &src_;
/*! \brief size of the last dimension of dst */
index_t dst_last_;
/*! \brief number of broadcasting axes*/
index_t axesnum_;
  /*! \brief product of the dimensions after each broadcasting axis */
  Shape<dimsrc> trailings_;
  /*! \brief new sizes of the broadcasting axes */
Shape<dimsrc> sizes_;
/*! \brief size of the last dimension of src*/
index_t last_;
/*! constructor */
template<typename TShape>
BroadcastWithMultiAxesExp(const SrcExp &src, const TShape& axes, const TShape& sizes)
: src_(src) {
Shape<dimsrc> src_shape = ShapeCheck<dimsrc, SrcExp>::Check(src_);
CHECK(axes.ndim() == sizes.ndim()) << "ndim of axes and sizes must be equal.";
this->axesnum_ = axes.ndim();
    CHECK(this->axesnum_ <= dimsrc) << "Number of broadcasting axes must not exceed "
      "the source ndim, number of axes=" << this->axesnum_ << " dimsrc=" << dimsrc;
for (index_t i = 0; i < this->axesnum_; i++) {
      CHECK(dimsrc > axes[i]) << "broadcast axis (keepdim) out of bound, " <<
        "all axes must be between 0 and " << dimsrc - 1 << ", given axes[" << i << "] = " << axes[i]
<< ".";
CHECK_EQ(src_shape[axes[i]], 1) << "Size of the dimension of the broadcasting axis must be 1"
<< ", src_shape[" << axes[i] << "]=" << src_shape[axes[i]] << ".";
if (i < this->axesnum_ - 1) {
CHECK(axes[i] < axes[i + 1]) << "The given axes must be in increasing order.";
}
}
for (index_t i = 0; i < dimsrc; i++) {
this->shape_[i] = src_shape[i];
this->sizes_[i] = 1;
this->trailings_[i] = 1;
}
for (index_t i = 0; i < this->axesnum_; i++) {
this->shape_[axes[i]] = sizes[i];
this->sizes_[i] = sizes[i];
}
for (index_t i = 0; i < this->axesnum_; i++) {
this->trailings_[i] = 1;
for (index_t j = axes[i] + 1; j < dimsrc; ++j) {
this->trailings_[i] *= this->shape_[j];
}
}
this->last_ = src_shape[dimsrc - 1];
this->dst_last_ = this->shape_[dimsrc - 1];
}
}; // struct BroadcastWithMultiAxesExp
/*!
 * \brief Broadcast the tensor along multiple given axes; the size of src along each given axis must be 1.
* \param src source
* \param axes broadcasting axes
* \param sizes sizes of the broadcasting axes
* \tparam SrcExp source expression
* \tparam DType data type
* \tparam etype type of the expression
* \tparam TShape the flexible shape type
*/
template<typename SrcExp, typename DType, int etype, typename TShape>
inline BroadcastWithMultiAxesExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>
broadcast_multi_axes(const Exp<SrcExp, DType, etype> &src,
const TShape &axes, const TShape &sizes) {
return BroadcastWithMultiAxesExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>(src.self(), axes, sizes);
}
/*!
 * \brief Broadcast the tensor to the target shape;
     any dimension whose size differs from the target must have size 1 in the source tensor.
* \param src source
* \param target_shape shape of the target broadcasting tensor
* \tparam SrcExp source expression
* \tparam DType data type
* \tparam etype type of the expression
* \tparam TShape the flexible shape type
*/
template<typename SrcExp, typename DType, int etype, typename TShape>
inline BroadcastWithMultiAxesExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>
broadcast_to(const Exp<SrcExp, DType, etype> &src, const TShape &target_shape) {
static const int dimsrc = ExpInfo<SrcExp>::kDim;
CHECK_EQ(target_shape.ndim(), dimsrc);
std::vector<index_t> axes_vec, sizes_vec;
Shape<dimsrc> src_shape = ShapeCheck<dimsrc, SrcExp>::Check(src.self());
for (int i = 0; i < dimsrc; ++i) {
if (src_shape[i] != target_shape[i]) {
CHECK_EQ(src_shape[i], 1) << "broadcasting axis must have size 1, received shape="
<< src_shape << " target_shape=" << target_shape;
axes_vec.push_back(i);
sizes_vec.push_back(target_shape[i]);
}
}
TShape axes = TShape(axes_vec.begin(), axes_vec.end());
TShape sizes = TShape(sizes_vec.begin(), sizes_vec.end());
return BroadcastWithMultiAxesExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>(src.self(), axes, sizes);
}
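// Usage sketch (illustrative only; assumes a TShape type, such as mxnet's TShape, that provides
// ndim() and construction from an iterator range; the tensors and shapes below are assumptions):
//   Tensor<cpu, 2, float> src;   // shape (1, 4)
//   Tensor<cpu, 2, float> dst;   // shape (3, 4)
//   TShape target = ...;         // target shape (3, 4)
//   dst = broadcast_to(src, target);  // axes where src has size 1 are broadcast to the target size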
//----------------------
// Execution plan
//----------------------
template<typename SrcExp, typename DType, int dimsrc, int dimdst>
struct Plan<BroadcastWithAxisExp<SrcExp, DType, dimsrc, dimdst>, DType> {
public:
explicit Plan(const BroadcastWithAxisExp<SrcExp, DType, dimsrc, dimdst> &e)
: src_(MakePlan(e.src_)), dst_last_(e.dst_last_),
trailing_(e.trailing_), size_(e.size_), last_(e.last_) {}
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
index_t x = (i * dst_last_ + j) / trailing_ / size_;
index_t y = (i * dst_last_ + j) % trailing_;
index_t z = x * trailing_ + y;
return src_.Eval(z / last_, z % last_);
}
private:
Plan<SrcExp, DType> src_;
const index_t dst_last_, trailing_, size_, last_;
};
template<typename SrcExp, typename DType, int dimsrc>
struct Plan<BroadcastWithMultiAxesExp<SrcExp, DType, dimsrc>, DType> {
public:
explicit Plan(const BroadcastWithMultiAxesExp<SrcExp, DType, dimsrc> &e)
: src_(MakePlan(e.src_)), dst_last_(e.dst_last_), last_(e.last_), axesnum_(e.axesnum_),
trailings_(e.trailings_), sizes_(e.sizes_) {}
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
index_t indx = i * dst_last_ + j;
for (index_t p = 0; p < dimsrc; ++p) {
if (p >= axesnum_) {
break;
}
indx = (indx / trailings_[p] / sizes_[p]) * trailings_[p] + (indx % trailings_[p]);
}
return src_.Eval(indx / last_, indx % last_);
}
private:
Plan<SrcExp, DType> src_;
const index_t dst_last_, last_, axesnum_;
const Shape<dimsrc> trailings_, sizes_;
};
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXTENSION_BROADCAST_WITH_AXIS_H_
//===== EXPANDED: ../mshadow/mshadow/extension/broadcast_with_axis.h =====
//===== EXPANDING: ../mshadow/mshadow/extension/spatial_upsampling_nearest.h =====
/*!
* Copyright (c) 2015 by Contributors
 * \file spatial_upsampling_nearest.h
* \brief
* \author Bing Xu
*/
#ifndef MSHADOW_EXTENSION_SPATIAL_UPSAMPLING_NEAREST_H_
#define MSHADOW_EXTENSION_SPATIAL_UPSAMPLING_NEAREST_H_
namespace mshadow {
namespace expr {
/*! \brief nearest neighbor upsampling
* out(x, y) = in(int(x / scale_x), int(y / scale_y))
* \tparam SrcExp source expression
* \tparam DType data type
* \tparam srcdim source dimension
*/
template<typename SrcExp, typename DType, int srcdim>
struct UpSamplingNearestExp :
public MakeTensorExp<UpSamplingNearestExp<SrcExp, DType, srcdim>,
SrcExp, srcdim, DType> {
  /*! \brief source operand */
const SrcExp &src_;
/*! \brief up sampling scale */
index_t scale_;
/*! \brief constructor */
UpSamplingNearestExp(const SrcExp &src, index_t scale)
: src_(src), scale_(scale) {
this->shape_ = ShapeCheck<srcdim, SrcExp>::Check(src_);
this->shape_[srcdim - 2] *= scale_;
this->shape_[srcdim - 1] *= scale_;
}
};
template<typename SrcExp, typename DType, int etype>
inline UpSamplingNearestExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>
upsampling_nearest(const Exp<SrcExp, DType, etype> &src, index_t scale) {
TypeCheckPass<ExpInfo<SrcExp>::kDim >= 2>
::Error_Expression_Does_Not_Meet_Dimension_Req();
return UpSamplingNearestExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>(src.self(), scale);
}
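// Usage sketch (illustrative only; the tensors and shapes below are assumptions):
//   Tensor<cpu, 4, float> data;  // shape (batch, channel, 8, 8)
//   Tensor<cpu, 4, float> out;   // shape (batch, channel, 16, 16)
//   out = upsampling_nearest(data, 2);  // out(y, x) = data(y / 2, x / 2) on the spatial dims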
template<typename SrcExp, typename DType, int srcdim>
struct Plan<UpSamplingNearestExp<SrcExp, DType, srcdim>, DType> {
public:
explicit Plan(const UpSamplingNearestExp<SrcExp, DType, srcdim> &e)
: src_(MakePlan(e.src_)),
scale_(e.scale_),
new_height_(e.shape_[srcdim - 2]),
src_height_(static_cast<index_t>(e.shape_[srcdim - 2] / e.scale_)) {}
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
const index_t x = j;
const index_t y = i % new_height_;
const index_t c = i / new_height_;
const index_t h = static_cast<index_t>(y / scale_);
const index_t w = static_cast<index_t>(x / scale_);
return src_.Eval(c * src_height_ + h, w);
}
private:
Plan<SrcExp, DType> src_;
const index_t scale_;
const index_t new_height_;
const index_t src_height_;
};
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXTENSION_SPATIAL_UPSAMPLING_NEAREST_H_
//===== EXPANDED: ../mshadow/mshadow/extension/spatial_upsampling_nearest.h =====
//===== EXPANDING: ../mshadow/mshadow/extension/transpose.h =====
/*!
* Copyright (c) 2016 by Contributors
* \file transpose.h
* \brief support for transpose
* \author Junyuan Xie
*/
#ifndef MSHADOW_EXTENSION_TRANSPOSE_H_
#define MSHADOW_EXTENSION_TRANSPOSE_H_
namespace mshadow {
namespace expr {
/*!
 * \brief transpose (permute) the axes of a tensor
 * input: Tensor<Device,dimsrc>: ishape
 * output: Tensor<Device,dimsrc> with oshape[i] = ishape[axes[i]]
 *
 * \tparam SrcExp type of source expression
 * \tparam DType the type of elements
 * \tparam dimsrc source dimension
 */
template<typename SrcExp, typename DType, int dimsrc>
struct TransposeExExp:
public MakeTensorExp<TransposeExExp<SrcExp, DType, dimsrc>,
SrcExp, dimsrc, DType> {
/*! \brief source expression */
const SrcExp &src_;
const Shape<dimsrc> axes_;
Shape<dimsrc> dst_in_src_stride_; // Holds the corresponding stride of the dst axes in src
index_t src_stride_;
/*! \brief constructor */
explicit TransposeExExp(const SrcExp &src, Shape<dimsrc> axes) : src_(src), axes_(axes) {
Shape<dimsrc> src_shape = ShapeCheck<dimsrc, SrcExp>::Check(src);
src_stride_ = src_shape[dimsrc - 1];
Shape<dimsrc> src_stride;
src_stride[dimsrc-1] = 1;
for (int i = dimsrc-2; i >= 0; --i) src_stride[i] = src_shape[i+1]*src_stride[i+1];
for (int i = 0; i < dimsrc; ++i) {
dst_in_src_stride_[i] = src_stride[axes[i]];
this->shape_[i] = src_shape[axes[i]];
}
}
};
/*!
 * \brief an expression that transposes a tensor by permuting its axes
 * \param src Tensor<Device,dimsrc>
 * \param axes the permutation of the axes; output dimension i is taken from input dimension axes[i]
 * \return an expression with type Tensor<Device,dimsrc>
 * \tparam SrcExp source expression
 * \tparam DType the type of elements
 * \tparam etype source expression type
 */
template<typename SrcExp, typename DType, int etype>
inline TransposeExExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>
transpose(const Exp<SrcExp, DType, etype> &src, Shape<ExpInfo<SrcExp>::kDim> axes) {
return TransposeExExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>(src.self(), axes);
}
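// Usage sketch (illustrative only; the tensors and shapes below are assumptions):
//   Tensor<cpu, 3, float> src;  // shape (2, 3, 4)
//   Tensor<cpu, 3, float> dst;  // shape (2, 4, 3)
//   dst = transpose(src, Shape3(0, 2, 1));  // dst[i][k][j] = src[i][j][k]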
template<typename SrcExp, typename DType, int dimsrc>
struct Plan<TransposeExExp<SrcExp, DType, dimsrc>, DType> {
public:
explicit Plan(const TransposeExExp<SrcExp, DType, dimsrc> &e)
: src_(MakePlan(e.src_)),
src_stride_(e.src_stride_),
dst_in_src_stride_(e.dst_in_src_stride_),
dst_shape_(e.shape_) {}
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
index_t idx = j * dst_in_src_stride_[dimsrc - 1];
#pragma unroll
for (int k = dimsrc-2; k >= 0; --k) {
idx += (i % dst_shape_[k]) * dst_in_src_stride_[k];
i /= dst_shape_[k];
}
return src_.Eval(idx/src_stride_, idx%src_stride_);
}
private:
Plan<SrcExp, DType> src_;
const index_t src_stride_;
const Shape<dimsrc> dst_in_src_stride_, dst_shape_;
};
/*!
* \brief transform contiguous indices of the source tensor to indices of the transposed tensor.
* input: Tensor<Device, k>: ishape
* output: Tensor<Device, k>: oshape = ishape
*
* \tparam SrcExp type of source expression
* \tparam DType the type of elements
* \tparam dimsrc source dimension
* \tparam etype source type
*/
template<typename SrcExp, typename DType, int dimsrc, int etype>
struct TransposeIndicesExp:
public Exp<TransposeIndicesExp<SrcExp, DType, dimsrc, etype>, DType, etype> {
/*! \brief source expression */
const SrcExp &src_indices_; // Expression of the source indices
  Shape<dimsrc> src_shape_;  // Shape of the source tensor
const Shape<dimsrc> axes_; // The transpose axes
Shape<dimsrc> src_in_dst_stride_; // Holds the corresponding stride of the source axes in dst
/*! \brief constructor */
explicit TransposeIndicesExp(const SrcExp &src_indices,
Shape<dimsrc> src_shape,
Shape<dimsrc> axes) : src_indices_(src_indices),
src_shape_(src_shape), axes_(axes) {
Shape<dimsrc> dst_shape_;
Shape<dimsrc> dst_stride_;
bool axes_checking_flag[dimsrc] = { 0 };
for (int i = 0; i < dimsrc; ++i) {
CHECK_LT(axes[i], dimsrc)
<< "Invalid axes input! All elements of axes must be between 0 and " << dimsrc
<< ", find axes=" << axes;
dst_shape_[i] = src_shape[axes[i]];
axes_checking_flag[axes[i]] = true;
}
// check if the input axes is valid
for (int i = 0; i < dimsrc; ++i) {
CHECK_EQ(axes_checking_flag[i], true)
<< "Invalid axes input! All elements of axes must be between 0 and " << dimsrc
<< ", find axes=" << axes;
}
dst_stride_[dimsrc - 1] = 1;
for (int i = dimsrc - 2; i >= 0; --i) dst_stride_[i] = dst_shape_[i+1] * dst_stride_[i+1];
for (int i = 0; i < dimsrc; ++i) {
src_in_dst_stride_[axes[i]] = dst_stride_[i];
}
}
};
/*!
 * \brief an expression that maps contiguous indices of the source tensor to indices of the transposed tensor
 * \param src_indices expression holding flat indices into the source tensor
 * \param src_shape shape of the source tensor
 * \param axes the transpose permutation
 * \return an expression yielding the corresponding flat indices in the transposed tensor
 * \tparam SrcExp source expression
 * \tparam DType the type of elements
 * \tparam etype source expression type
 */
template<typename SrcExp, typename DType, int dimsrc, int etype>
inline TransposeIndicesExp<SrcExp, DType, dimsrc, etype>
transpose_indices(const Exp<SrcExp, DType, etype> &src_indices,
Shape<dimsrc> src_shape,
Shape<dimsrc> axes) {
return TransposeIndicesExp<SrcExp, DType, dimsrc, etype>(src_indices.self(), src_shape, axes);
}
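// Usage sketch (illustrative only; the tensors below are assumptions):
//   Tensor<cpu, 1, float> flat_idx;  // flat indices into a tensor of shape (2, 3, 4)
//   Tensor<cpu, 1, float> out_idx;   // corresponding flat indices after transposing to (2, 4, 3)
//   out_idx = transpose_indices(flat_idx, Shape3(2, 3, 4), Shape3(0, 2, 1));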
template<typename SrcExp, typename DType, int dimsrc, int etype>
struct Plan<TransposeIndicesExp<SrcExp, DType, dimsrc, etype>, DType> {
public:
explicit Plan(const TransposeIndicesExp<SrcExp, DType, dimsrc, etype> &e)
: src_indices_(MakePlan(e.src_indices_)),
src_in_dst_stride_(e.src_in_dst_stride_),
src_shape_(e.src_shape_) {}
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
index_t src_idx = static_cast<index_t>(src_indices_.Eval(i, j));
index_t dst_idx = 0;
#pragma unroll
for (int k = dimsrc - 1; k >= 0; --k) {
dst_idx += (src_idx % src_shape_[k]) * src_in_dst_stride_[k];
src_idx /= src_shape_[k];
}
return static_cast<DType>(dst_idx);
}
private:
Plan<SrcExp, DType> src_indices_;
const Shape<dimsrc> src_in_dst_stride_, src_shape_;
};
//----------------------
// Execution plan
//----------------------
/*! \brief make expression */
template<typename SrcExp, typename DType, int dimsrc, int etype>
inline Plan<TransposeIndicesExp<SrcExp, DType, dimsrc, etype>, DType>
MakePlan(const TransposeIndicesExp<SrcExp, DType, dimsrc, etype> &e) {
return Plan<TransposeIndicesExp<SrcExp, DType, dimsrc, etype>, DType>(e);
}
template<int dim, typename SrcExp, typename DType, int dimsrc, int etype>
struct ShapeCheck<dim, TransposeIndicesExp<SrcExp, DType, dimsrc, etype> > {
inline static Shape<dim>
Check(const TransposeIndicesExp<SrcExp, DType, dimsrc, etype> &t) {
Shape<dim> s = ShapeCheck<dim, SrcExp>::Check(t.src_indices_);
return s;
}
};
template<typename SrcExp, typename DType, int dimsrc, int etype>
struct ExpInfo<TransposeIndicesExp<SrcExp, DType, dimsrc, etype> > {
static const int kDim = ExpInfo<SrcExp>::kDim;
static const int kDevMask = ExpInfo<SrcExp>::kDevMask;
};
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXTENSION_TRANSPOSE_H_
//===== EXPANDED: ../mshadow/mshadow/extension/transpose.h =====
//===== EXPANDING: ../mshadow/mshadow/extension/flip.h =====
/*!
* Copyright (c) 2016 by Contributors
* \file flip.h
* \brief support for flip a certain dimension.
* \author Junyuan Xie
*/
#ifndef MSHADOW_EXTENSION_FLIP_H_
#define MSHADOW_EXTENSION_FLIP_H_
namespace mshadow {
namespace expr {
/*!
 * \brief flip expression, reverse a tensor along a given dimension
 * \tparam SrcExp source expression
 * \tparam DType the type of elements
 * \tparam srcdim dimension of src
 */
template<typename SrcExp, typename Device,
typename DType, int srcdim>
struct FlipExp : public TRValue<FlipExp<SrcExp,
Device, DType,
srcdim>,
Device, srcdim, DType> {
const SrcExp &src_;
index_t trailing_;
index_t stride_;
index_t stride_j_;
Shape<srcdim> shape_;
FlipExp(const SrcExp &src, int dim)
: src_(src) {
shape_ = ShapeCheck<srcdim, SrcExp>::Check(src_);
stride_ = shape_[dim];
stride_j_ = shape_[srcdim-1];
trailing_ = 1;
for (int i = dim + 1; i < srcdim; ++i) {
trailing_ *= shape_[i];
}
}
template<typename E, int etype>
inline void
operator=(const expr::Exp<E, DType, etype> &exp) {
this->__assign(exp);
}
inline void
operator=(const DType &exp) {
this->__assign(exp);
}
}; // struct Flip
/*!
 * \brief Flip a Tensor along one dimension
 * \param src source tensor
 * \param dim the dimension to flip
 * \return flipped tensor expression
 * \tparam SrcExp source expression
 * \tparam DType the type of elements
 * \tparam srcdim dimension of src
 */
template<typename SrcExp, typename Device,
typename DType, int srcdim>
inline FlipExp<SrcExp, Device, DType, srcdim>
flip(const TRValue<SrcExp, Device, srcdim, DType> &src, int dim) {
return FlipExp<SrcExp, Device, DType, srcdim>(src.self(), dim);
}
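// Usage sketch (illustrative only; the tensors and shapes below are assumptions):
//   Tensor<cpu, 4, float> img;      // shape (batch, channel, H, W)
//   Tensor<cpu, 4, float> flipped;  // same shape
//   flipped = flip(img, 3);  // reverse the W axis: flipped[n][c][h][w] = img[n][c][h][W - 1 - w]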
//------------------------
// engine plugin
//------------------------
// runtime shapecheck
template<typename SrcExp, typename Device,
typename DType, int srcdim>
struct ShapeCheck<srcdim, FlipExp<SrcExp, Device, DType, srcdim> >{
inline static Shape<srcdim> Check(const FlipExp<SrcExp,
Device, DType, srcdim> &t) {
return t.shape_;
}
};
template<typename SrcExp, typename Device,
typename DType, int srcdim>
struct StreamInfo<Device, FlipExp<SrcExp, Device, DType, srcdim> >{
inline static Stream<Device> *
Get(const FlipExp<SrcExp, Device, DType, srcdim> &t) {
return StreamInfo<Device, SrcExp>::Get(t.src_);
}
};
// static typecheck
template<typename SrcExp, typename Device,
typename DType, int srcdim>
struct ExpInfo<FlipExp<SrcExp, Device, DType, srcdim> >{
static const int kDim = ExpInfo<SrcExp>::kDim;
static const int kDevMask = ExpInfo<SrcExp>::kDevMask;
};
//----------------------
// Execution plan
//---------------------
template<typename SrcExp, typename Device,
typename DType, int srcdim>
struct Plan<FlipExp<SrcExp, Device, DType, srcdim>, DType> {
public:
explicit Plan(const FlipExp<SrcExp, Device, DType, srcdim> &e)
: src_(MakePlan(e.src_)), stride_j_(e.stride_j_),
trailing_(e.trailing_), stride_(e.stride_) {}
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
index_t idx = i*stride_j_+j;
const index_t low = idx%trailing_;
index_t high = idx/trailing_;
const index_t x = high%stride_;
high /= stride_;
idx = (high*stride_+stride_-1-x)*trailing_+low;
return src_.Eval(idx/stride_j_, idx%stride_j_);
}
MSHADOW_XINLINE DType &REval(index_t i, index_t j) const {
index_t idx = i*stride_j_+j;
const index_t low = idx%trailing_;
index_t high = idx/trailing_;
const index_t x = high%stride_;
high /= stride_;
idx = (high*stride_+stride_-1-x)*trailing_+low;
return src_.REval(idx/stride_j_, idx%stride_j_);
}
private:
Plan<SrcExp, DType> src_;
const index_t stride_j_, trailing_, stride_;
}; // struct Plan
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXTENSION_FLIP_H_
//===== EXPANDED: ../mshadow/mshadow/extension/flip.h =====
//===== EXPANDING: ../mshadow/mshadow/extension/complex.h =====
/*!
* Copyright (c) 2016 by Contributors
* \file complex.h
* \brief support for complex operations
* \author Xingjian Shi
*/
#ifndef MSHADOW_EXTENSION_COMPLEX_H_
#define MSHADOW_EXTENSION_COMPLEX_H_
namespace mshadow {
namespace op {
namespace complex {
enum BinaryCalculationType { kBinaryCC, kBinaryCR, kBinaryRC};
enum UnitaryCalculationType { kUnitaryC2R, kUnitaryC2C };
struct mul {
/*! \brief map a_real, a_imag, b_real, b_imag to result using defined operation */
template<typename DType>
MSHADOW_XINLINE static DType RealMap(DType a_real, DType a_imag,
DType b_real, DType b_imag) {
return a_real * b_real - a_imag * b_imag;
}
template<typename DType>
MSHADOW_XINLINE static DType ImagMap(DType a_real, DType a_imag,
DType b_real, DType b_imag) {
return a_real * b_imag + b_real * a_imag;
}
};
struct div {
/*! \brief map a_real, a_imag, b_real, b_imag to result using defined operation */
template<typename DType>
MSHADOW_XINLINE static DType RealMap(DType a_real, DType a_imag,
DType b_real, DType b_imag) {
return (a_real * b_real + a_imag * b_imag) / (b_real * b_real + b_imag * b_imag);
}
template<typename DType>
MSHADOW_XINLINE static DType ImagMap(DType a_real, DType a_imag,
DType b_real, DType b_imag) {
return (b_real * a_imag - a_real * b_imag) / (b_real * b_real + b_imag * b_imag);
}
};
struct conjugate {
template<typename TA, typename DType>
MSHADOW_XINLINE static DType RealMap(const expr::Plan<TA, DType> &src_,
index_t real_i, index_t real_j, index_t imag_i, index_t imag_j) {
return src_.Eval(real_i, real_j);
}
template<typename TA, typename DType>
MSHADOW_XINLINE static DType ImagMap(const expr::Plan<TA, DType> &src_,
index_t real_i, index_t real_j, index_t imag_i, index_t imag_j) {
return -src_.Eval(imag_i, imag_j);
}
};
struct exchange {
template<typename TA, typename DType>
MSHADOW_XINLINE static DType RealMap(const expr::Plan<TA, DType> &src_,
index_t real_i, index_t real_j, index_t imag_i, index_t imag_j) {
return src_.Eval(imag_i, imag_j);
}
template<typename TA, typename DType>
MSHADOW_XINLINE static DType ImagMap(const expr::Plan<TA, DType> &src_,
index_t real_i, index_t real_j, index_t imag_i, index_t imag_j) {
return src_.Eval(real_i, real_j);
}
};
struct abs_square {
template<typename TA, typename DType>
MSHADOW_XINLINE static DType RealMap(const expr::Plan<TA, DType> &src_,
index_t real_i, index_t real_j, index_t imag_i, index_t imag_j) {
DType real_val = src_.Eval(real_i, real_j);
DType image_val = src_.Eval(imag_i, imag_j);
return real_val * real_val + image_val * image_val;
}
};
struct sum_real_imag {
template<typename TA, typename DType>
MSHADOW_XINLINE static DType RealMap(const expr::Plan<TA, DType> &src_,
index_t real_i, index_t real_j, index_t imag_i, index_t imag_j) {
DType real_val = src_.Eval(real_i, real_j);
DType image_val = src_.Eval(imag_i, imag_j);
return real_val + image_val;
}
};
} // namespace complex
} // namespace op
namespace expr {
//--------------------
// ComplexBinaryMapExp
//--------------------
/*!
* \brief binary map expression lhs [op] rhs where lhs and rhs are complex tensors
* \tparam OP operator
* \tparam calctype type of the calculation
* \tparam TA type of lhs
* \tparam TB type of rhs
* \tparam etype expression type, sa namespace::type
*/
template<int calctype, typename OP, typename TA, typename TB, typename DType, int etype>
struct ComplexBinaryMapExp : public Exp<ComplexBinaryMapExp<calctype, OP, TA, TB, DType, etype>,
DType, etype> {
/*! \brief left operand */
const TA &lhs_;
/*! \brief right operand */
const TB &rhs_;
/*! \brief constructor */
explicit ComplexBinaryMapExp(const TA &lhs, const TB &rhs)
:lhs_(lhs), rhs_(rhs) {}
};
//-------------------
// ComplexConjExp
//-------------------
/*!
 * \brief unary map expression OP(src) applied to a complex tensor, e.g. conj(src)
 * \tparam TA type of src
 * \tparam etype expression type, sa namespace::type
 */
template<int calctype, typename OP, typename TA, typename DType, int etype>
struct ComplexUnitaryExp : public Exp<ComplexUnitaryExp<calctype, OP, TA, DType, etype>,
DType, etype> {
/*! \brief source expression */
const TA &src_;
/*! \brief constructor */
explicit ComplexUnitaryExp(const TA &src) : src_(src) {}
};
template<int calctype, typename OP, typename TA, typename TB, typename DType, int ta, int tb>
inline ComplexBinaryMapExp<calctype, OP, TA, TB, DType, (ta | tb | type::kMapper)>
ComplexF(const Exp<TA, DType, ta> &lhs, const Exp<TB, DType, tb> &rhs) {
return ComplexBinaryMapExp<calctype, OP, TA, TB, DType,
(ta | tb | type::kMapper)>(lhs.self(), rhs.self());
}
/*!
 * \brief conj: negate the imaginary part of A, where A is a complex tensor
 * \param src source tensor
 * \tparam e1 type of source expression
 */
template<int calctype, typename OP, typename SrcExp, typename DType, int e1>
inline ComplexUnitaryExp<calctype, OP, SrcExp, DType, (e1 | type::kMapper)>
ComplexF(const Exp<SrcExp, DType, e1> &src) {
return ComplexUnitaryExp<calctype, OP, SrcExp, DType, (e1 | type::kMapper)>(src.self());
}
/*!
 * \brief complex_mul_cc Complex multiplication of two complex tensors, A * B
 */
template<typename TA, typename TB, typename DType, int ta, int tb>
inline ComplexBinaryMapExp<op::complex::kBinaryCC, op::complex::mul,
TA, TB, DType, (ta | tb | type::kMapper)>
complex_mul_cc(const Exp<TA, DType, ta> &lhs, const Exp<TB, DType, tb> &rhs) {
return ComplexF<op::complex::kBinaryCC, op::complex::mul>(lhs, rhs);
}
/*!
 * \brief complex_mul_cr Complex multiplication of a complex tensor A and a real tensor B
 */
template<typename TA, typename TB, typename DType, int ta, int tb>
inline ComplexBinaryMapExp<op::complex::kBinaryCR, op::complex::mul,
TA, TB, DType, (ta | tb | type::kMapper)>
complex_mul_cr(const Exp<TA, DType, ta> &lhs, const Exp<TB, DType, tb> &rhs) {
return ComplexF<op::complex::kBinaryCR, op::complex::mul>(lhs, rhs);
}
/*!
 * \brief complex_mul_rc Complex multiplication of a real tensor A and a complex tensor B
 */
template<typename TA, typename TB, typename DType, int ta, int tb>
inline ComplexBinaryMapExp<op::complex::kBinaryRC, op::complex::mul,
TA, TB, DType, (ta | tb | type::kMapper)>
complex_mul_rc(const Exp<TA, DType, ta> &lhs, const Exp<TB, DType, tb> &rhs) {
return ComplexF<op::complex::kBinaryRC, op::complex::mul>(lhs, rhs);
}
/*!
 * \brief complex_div_cc Complex division of two complex tensors, A / B
 */
template<typename TA, typename TB, typename DType, int ta, int tb>
inline ComplexBinaryMapExp<op::complex::kBinaryCC, op::complex::div,
TA, TB, DType, (ta | tb | type::kMapper)>
complex_div_cc(const Exp<TA, DType, ta> &lhs, const Exp<TB, DType, tb> &rhs) {
return ComplexF<op::complex::kBinaryCC, op::complex::div>(lhs, rhs);
}
/*!
 * \brief complex_div_cr Complex division of a complex tensor A by a real tensor B
 */
template<typename TA, typename TB, typename DType, int ta, int tb>
inline ComplexBinaryMapExp<op::complex::kBinaryCR, op::complex::div,
TA, TB, DType, (ta | tb | type::kMapper)>
complex_div_cr(const Exp<TA, DType, ta> &lhs, const Exp<TB, DType, tb> &rhs) {
return ComplexF<op::complex::kBinaryCR, op::complex::div>(lhs, rhs);
}
/*!
 * \brief complex_div_rc Complex division of a real tensor A by a complex tensor B
 */
template<typename TA, typename TB, typename DType, int ta, int tb>
inline ComplexBinaryMapExp<op::complex::kBinaryRC, op::complex::div,
TA, TB, DType, (ta | tb | type::kMapper)>
complex_div_rc(const Exp<TA, DType, ta> &lhs, const Exp<TB, DType, tb> &rhs) {
return ComplexF<op::complex::kBinaryRC, op::complex::div>(lhs, rhs);
}
/*!
 * \brief conj: negate the imaginary part of A, where A is a complex tensor
 * \param src source tensor
 * \tparam e1 type of source expression
 */
template<typename SrcExp, typename DType, int e1>
inline ComplexUnitaryExp<op::complex::kUnitaryC2C, op::complex::conjugate,
SrcExp, DType, (e1|type::kMapper)>
conj(const Exp<SrcExp, DType, e1> &src) {
return ComplexF<op::complex::kUnitaryC2C, op::complex::conjugate>(src);
}
/*!
* \brief complex_exchange Exchange the real and imaginary part of A where A is a complex tensor
* \param src source tensor
* \tparam e1 type of source expression
*/
template<typename SrcExp, typename DType, int e1>
inline ComplexUnitaryExp<op::complex::kUnitaryC2C, op::complex::exchange,
SrcExp, DType, (e1|type::kMapper)>
complex_exchange(const Exp<SrcExp, DType, e1> &src) {
return ComplexF<op::complex::kUnitaryC2C, op::complex::exchange>(src);
}
/*!
* \brief complex_abs_square calculate the square of the modulus of A where A is a complex tensor
* \param src source tensor
* \tparam e1 type of source expression
*/
template<typename SrcExp, typename DType, int e1>
inline ComplexUnitaryExp<op::complex::kUnitaryC2R, op::complex::abs_square,
SrcExp, DType, (e1 | type::kMapper)>
complex_abs_square(const Exp<SrcExp, DType, e1> &src) {
return ComplexF<op::complex::kUnitaryC2R, op::complex::abs_square>(src);
}
template<typename SrcExp, typename DType, int e1>
inline ComplexUnitaryExp<op::complex::kUnitaryC2R, op::complex::sum_real_imag,
SrcExp, DType, (e1 | type::kMapper)>
complex_sum_real_imag(const Exp<SrcExp, DType, e1> &src) {
return ComplexF<op::complex::kUnitaryC2R, op::complex::sum_real_imag>(src);
}
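// Usage sketch (illustrative only; the tensors and shapes below are assumptions).
// Complex tensors store real/imaginary parts interleaved along the last dimension, so the last
// dimension of a complex tensor is twice that of the corresponding real tensor:
//   Tensor<cpu, 2, float> a, b, c;  // shape (n, 2 * m), complex
//   Tensor<cpu, 2, float> r;        // shape (n, m), real
//   c = complex_mul_cc(a, b);       // element-wise complex product a * b
//   c = complex_div_cr(a, r);       // divide complex a by real r
//   c = conj(a);                    // negate the imaginary parts of a
//   r = complex_abs_square(a);      // |a|^2, real output with last dimension m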
template<int dim, int calctype, typename OP, typename TA, typename TB,
typename DType, int etype>
struct ShapeCheck<dim, ComplexBinaryMapExp<calctype, OP, TA, TB, DType, etype> > {
inline static Shape<dim>
Check(const ComplexBinaryMapExp<calctype, OP, TA, TB, DType, etype> &t) {
Shape<dim> shape1 = ShapeCheck<dim, TA>::Check(t.lhs_);
Shape<dim> shape2 = ShapeCheck<dim, TB>::Check(t.rhs_);
if (shape1[0] == 0) return shape2;
if (shape2[0] == 0) return shape1;
if (calctype == op::complex::kBinaryCC) {
CHECK_EQ(shape1, shape2) << "ComplexBinaryMapExp (CC): Shapes of operands are not the same.";
CHECK_EQ(shape1[dim - 1] % 2, 0) <<
"ComplexBinaryMapExp (CC): Shape of the last dimension is not even. "
"We must have real part + imaginary part.";
return shape1;
} else if (calctype == op::complex::kBinaryCR) {
for (int i = 0; i < dim - 1; ++i) {
CHECK_EQ(shape1.shape_[i], shape2.shape_[i]) <<
"ComplexBinaryMapExp (CR): Shapes of operands are not the same.";
}
CHECK_EQ(shape1[dim - 1], shape2[dim - 1] * 2) <<
"ComplexBinaryMapExp (CR): Shapes of operands do not match.";
return shape1;
} else if (calctype == op::complex::kBinaryRC) {
for (int i = 0; i < dim - 1; ++i) {
CHECK_EQ(shape1.shape_[i], shape2.shape_[i]) <<
"ComplexBinaryMapExp (RC): Shapes of operands are not the same.";
}
CHECK_EQ(shape2[dim - 1], shape1[dim - 1] * 2) <<
"ComplexBinaryMapExp (RC): Shapes of operands do not match.";
return shape2;
} else {
LOG(FATAL) << "ComplexBinaryMapExp: Unexpected Calculation Type!";
return shape1;
}
}
};
template<int dim, int calctype, typename OP, typename TA, typename DType, int etype>
struct ShapeCheck<dim, ComplexUnitaryExp<calctype, OP, TA, DType, etype> > {
inline static Shape<dim> Check(const ComplexUnitaryExp<calctype, OP, TA, DType, etype> &t) {
Shape<dim> s = ShapeCheck<dim, TA>::Check(t.src_);
CHECK_EQ(s[dim - 1] % 2, 0) << "ComplexUnitaryExp: Shape of the last dimension is not even. "
"We must have real + imaginary.";
if (calctype == op::complex::kUnitaryC2C) {
return s;
} else if (calctype == op::complex::kUnitaryC2R) {
Shape<dim> s_ret = s;
s_ret[dim - 1] /= 2;
return s_ret;
} else {
LOG(FATAL) << "ComplexUnitaryExp: Unexpected Calculation Type!";
return s;
}
}
};
// complex binary expression (cc)
template<typename OP, typename TA, typename TB, int etype, typename DType>
class Plan<ComplexBinaryMapExp<op::complex::kBinaryCC, OP, TA, TB, DType, etype>, DType> {
public:
explicit Plan(const Plan<TA, DType> &lhs, const Plan<TB, DType> &rhs)
: lhs_(lhs), rhs_(rhs) {}
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
const index_t base_x = static_cast<index_t>(x / 2) * 2;
if (x % 2 == 0) {
return OP::RealMap(lhs_.Eval(y, base_x), lhs_.Eval(y, base_x + 1),
rhs_.Eval(y, base_x), rhs_.Eval(y, base_x + 1));
} else {
return OP::ImagMap(lhs_.Eval(y, base_x), lhs_.Eval(y, base_x + 1),
rhs_.Eval(y, base_x), rhs_.Eval(y, base_x + 1));
}
}
private:
Plan<TA, DType> lhs_;
Plan<TB, DType> rhs_;
};
// complex binary expression (cr)
template<typename OP, typename TA, typename TB, int etype, typename DType>
class Plan<ComplexBinaryMapExp<op::complex::kBinaryCR, OP, TA, TB, DType, etype>, DType> {
public:
explicit Plan(const Plan<TA, DType> &lhs, const Plan<TB, DType> &rhs)
: lhs_(lhs), rhs_(rhs) {}
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
const index_t base_x = static_cast<index_t>(x / 2) * 2;
if (x % 2 == 0) {
return OP::RealMap(lhs_.Eval(y, base_x), lhs_.Eval(y, base_x + 1),
rhs_.Eval(y, base_x / 2), static_cast<DType>(0));
} else {
return OP::ImagMap(lhs_.Eval(y, base_x), lhs_.Eval(y, base_x + 1),
rhs_.Eval(y, base_x / 2), static_cast<DType>(0));
}
}
private:
Plan<TA, DType> lhs_;
Plan<TB, DType> rhs_;
};
// complex binary expression (rc)
template<typename OP, typename TA, typename TB, int etype, typename DType>
class Plan<ComplexBinaryMapExp<op::complex::kBinaryRC, OP, TA, TB, DType, etype>, DType> {
public:
explicit Plan(const Plan<TA, DType> &lhs, const Plan<TB, DType> &rhs)
: lhs_(lhs), rhs_(rhs) {}
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
const index_t base_x = static_cast<index_t>(x / 2) * 2;
if (x % 2 == 0) {
return OP::RealMap(lhs_.Eval(y, base_x / 2), static_cast<DType>(0),
rhs_.Eval(y, base_x), rhs_.Eval(y, base_x + 1));
} else {
return OP::ImagMap(lhs_.Eval(y, base_x / 2), static_cast<DType>(0),
rhs_.Eval(y, base_x), rhs_.Eval(y, base_x + 1));
}
}
private:
Plan<TA, DType> lhs_;
Plan<TB, DType> rhs_;
};
// complex unitary expression (c2c)
template<typename OP, typename TA, int etype, typename DType>
class Plan<ComplexUnitaryExp<op::complex::kUnitaryC2C, OP, TA, DType, etype>, DType> {
public:
explicit Plan(const Plan<TA, DType> &src) : src_(src) {}
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
const index_t base_x = static_cast<index_t>(x / 2) * 2;
if (0 == x % 2) {
return OP::RealMap(src_, y, base_x, y, base_x + 1);
} else {
return OP::ImagMap(src_, y, base_x, y, base_x + 1);
}
}
private:
Plan<TA, DType> src_;
};
// complex unitary expression (c2r)
template<typename OP, typename TA, int etype, typename DType>
class Plan<ComplexUnitaryExp<op::complex::kUnitaryC2R, OP, TA, DType, etype>, DType> {
public:
explicit Plan(const Plan<TA, DType> &src) : src_(src) {}
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
return OP::RealMap(src_, y, x * 2, y, x * 2 + 1);
}
private:
Plan<TA, DType> src_;
};
template<int calctype, typename OP, typename TA, typename TB, typename DType, int etype>
inline Plan<ComplexBinaryMapExp<calctype, OP, TA, TB, DType, etype>, DType>
MakePlan(const ComplexBinaryMapExp<calctype, OP, TA, TB, DType, etype> &e) {
return Plan<ComplexBinaryMapExp<calctype, OP, TA, TB, DType, etype>,
DType>(MakePlan(e.lhs_), MakePlan(e.rhs_));
}
template<int calctype, typename OP, typename TA, typename DType, int etype>
inline Plan<ComplexUnitaryExp<calctype, OP, TA, DType, etype>, DType>
MakePlan(const ComplexUnitaryExp<calctype, OP, TA, DType, etype> &e) {
return Plan<ComplexUnitaryExp<calctype, OP, TA, DType, etype>,
DType>(MakePlan(e.src_));
}
template<int calctype, typename OP, typename TA, typename TB, typename DType, int etype>
struct ExpInfo<ComplexBinaryMapExp<calctype, OP, TA, TB, DType, etype> > {
static const int kDimLhs = ExpInfo<TA>::kDim;
static const int kDimRhs = ExpInfo<TB>::kDim;
static const int kDim = (kDimLhs >= 0 && kDimRhs >= 0) ? \
(kDimLhs == 0 ? \
kDimRhs : \
((kDimRhs == 0 || kDimLhs == kDimRhs) ? kDimLhs : -1)) : -1;
static const int kDevMask = ExpInfo<TA>::kDevMask & ExpInfo<TB>::kDevMask;
};
template<int calctype, typename OP, typename TA, typename DType, int etype>
struct ExpInfo<ComplexUnitaryExp<calctype, OP, TA, DType, etype> > {
static const int kDim = ExpInfo<TA>::kDim;
static const int kDevMask = ExpInfo<TA>::kDevMask;
};
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXTENSION_COMPLEX_H_
//===== EXPANDED: ../mshadow/mshadow/extension/complex.h =====
//===== EXPANDING: ../mshadow/mshadow/extension/range.h =====
/*!
* Copyright (c) 2014 by Contributors
* \file range.h
* \brief support generating a range vector
* \author Xingjian Shi
*/
#ifndef MSHADOW_EXTENSION_RANGE_H_
#define MSHADOW_EXTENSION_RANGE_H_
namespace mshadow {
namespace expr {
/*!
 * \brief Generate a range vector similar to python: range(start, stop[, step][, repeat]).
          If step is positive, the last element is the largest start + i * step less than stop.
          If step is negative, the last element is the smallest start + i * step greater than stop.
          Each element is repeated `repeat` times, e.g. range(0, 4, 2, 3) --> 0, 0, 0, 2, 2, 2
 * \tparam DType the type of elements
 */
template<typename DType>
struct RangeExp:
public Exp<RangeExp<DType>, DType, type::kMapper> {
const float start_;
const float stop_;
const float step_;
const int repeat_;
/*! \brief constructor */
RangeExp(float start, float stop, float step, int repeat)
: start_(start), stop_(stop), step_(step), repeat_(repeat) {}
};
template<typename DType>
inline RangeExp<DType>
range(float start, float stop, float step = 1, int repeat = 1) {
return RangeExp<DType>(start, stop, step, repeat);
}
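// Usage sketch (illustrative only; the tensor below is an assumption):
//   Tensor<cpu, 1, float> v;       // shape (6,)
//   v = range<float>(0, 4, 2, 3);  // yields 0, 0, 0, 2, 2, 2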
//----------------------
// Execution plan
//----------------------
template<typename DType>
struct Plan<RangeExp<DType>, DType> {
public:
explicit Plan(const RangeExp<DType> &e)
: start_(e.start_),
stop_(e.stop_),
step_(e.step_),
repeat_(e.repeat_) {
}
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
return static_cast<DType>(start_ +
static_cast<float>((static_cast<int>(x) / repeat_)) * step_);
}
private:
const float start_;
const float stop_;
const float step_;
const int repeat_;
};
template<typename DType>
inline Plan<RangeExp<DType>, DType>
MakePlan(const RangeExp<DType> &exp) {
return Plan<RangeExp<DType>, DType>(exp);
}
template<int dim, typename DType>
struct ShapeCheck<dim, RangeExp<DType> > {
inline static Shape<dim>
Check(const RangeExp<DType> &t) {
    CHECK(dim == 1)
        << "RangeExp only supports 1-dimensional output, received " << dim;
CHECK(t.step_ != 0)
<< "RangeExp does not support step=0, received " << t.step_;
CHECK(t.repeat_ > 0)
<< "RangeExp only supports repeat > 0, received " << t.repeat_;
if (t.step_ > 0) {
CHECK(t.start_ < t.stop_) << "RangeExp does not support (start, stop, step) = "
<< "(" << t.start_ << "," << t.stop_ << "," << t.step_ << ")";
return Shape1(t.repeat_ * ceil((t.stop_ - t.start_) / t.step_));
} else {
CHECK(t.start_ > t.stop_) << "RangeExp does not support (start, stop, step)= "
<< "(" << t.start_ << "," << t.stop_ << "," << t.step_ << ")";
return Shape1(t.repeat_ * ceil((t.stop_ - t.start_) / t.step_));
}
}
};
template<typename DType>
struct ExpInfo<RangeExp<DType> > {
static const int kDim = 1;
static const int kDevMask = 0xffff;
};
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXTENSION_RANGE_H_
//===== EXPANDED: ../mshadow/mshadow/extension/range.h =====
//===== EXPANDING: ../mshadow/mshadow/extension/mask.h =====
/*!
* Copyright (c) 2016 by Contributors
* \file mask.h
* \brief
* \author Bing Xu
*/
#ifndef MSHADOW_EXTENSION_MASK_H_
#define MSHADOW_EXTENSION_MASK_H_
namespace mshadow {
namespace expr {
/*! \brief Broadcast a mask and do element-wise multiplication
* \tparam IndexExp type of index expression
* \tparam SrcExp type of src expression
* \tparam DType data type
*/
template<typename IndexExp, typename SrcExp, typename DType>
struct MaskExp: public Exp<MaskExp<IndexExp, SrcExp, DType>,
DType, type::kChainer> {
  /*! \brief index operand */
  const IndexExp &index_;
  /*! \brief matrix operand */
  const SrcExp &src_;
/*! constructor */
MaskExp(const IndexExp &index, const SrcExp &src)
: index_(index), src_(src) {}
}; // struct MaskExp
template<typename IndexExp,
typename SrcExp,
typename DType,
int e1, int e2>
inline MaskExp<IndexExp, SrcExp, DType>
mask(const Exp<IndexExp, DType, e1> &index,
const Exp<SrcExp, DType, e2> &src) {
return MaskExp<IndexExp, SrcExp, DType>(index.self(), src.self());
}
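// Usage sketch (illustrative only; the tensors and shapes below are assumptions):
//   Tensor<cpu, 1, float> keep;  // shape (n,), e.g. a 0/1 value per row
//   Tensor<cpu, 2, float> data;  // shape (n, m)
//   Tensor<cpu, 2, float> out;   // shape (n, m)
//   out = mask(keep, data);      // out[y][x] = data[y][x] * keep[y]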
//----------------------
// Execution plan
//----------------------
template<typename IndexExp, typename SrcExp, typename DType>
struct Plan<MaskExp<IndexExp, SrcExp, DType>, DType> {
public:
explicit Plan(const MaskExp<IndexExp, SrcExp, DType> &e)
: index_(MakePlan(e.index_)), src_(MakePlan(e.src_)) {
}
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
return static_cast<DType>(src_.Eval(y, x) * index_.Eval(0, y));
}
private:
expr::Plan<IndexExp, DType> index_;
expr::Plan<SrcExp, DType> src_;
}; // struct Plan
template<typename IndexExp, typename SrcExp, typename DType>
inline Plan<MaskExp<IndexExp, SrcExp, DType>, DType>
MakePlan(const MaskExp<IndexExp, SrcExp, DType> &exp) {
return Plan<MaskExp<IndexExp, SrcExp, DType>, DType>(exp);
}
template<int dim, typename IndexExp, typename SrcExp, typename DType>
struct ShapeCheck<dim, MaskExp<IndexExp, SrcExp, DType> > {
inline static Shape<dim>
Check(const MaskExp<IndexExp, SrcExp, DType> &t) {
    CHECK(dim == 2)
        << "MaskExp only supports 2D output";
    Shape<1> dshape = ShapeCheck<1, IndexExp>::Check(t.index_);
    Shape<2> wshape = ShapeCheck<2, SrcExp>::Check(t.src_);
    CHECK_EQ(dshape[0], wshape[0]) << "MaskExp requires inputs with the same first dimension";
Shape<dim> ret;
ret[0] = wshape[0];
ret[1] = wshape[1];
return ret;
}
};
template<typename IndexExp, typename SrcExp, typename DType>
struct ExpInfo<MaskExp<IndexExp, SrcExp, DType> > {
static const int kDim = 2;
static const int kDevMask = ExpInfo<IndexExp>::kDevMask;
};
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXTENSION_MASK_H_
//===== EXPANDED: ../mshadow/mshadow/extension/mask.h =====
#endif // MSHADOW_EXTENSION_H_
//===== EXPANDED: ../mshadow/mshadow/extension.h =====
//===== EXPANDING: ../mshadow/mshadow/tensor_cpu-inl.h =====
/*!
* Copyright (c) 2014 by Contributors
* \file tensor_cpu-inl.h
* \brief implementation of CPU host code
* \author Bing Xu, Tianqi Chen
*/
#ifndef MSHADOW_TENSOR_CPU_INL_H_
#define MSHADOW_TENSOR_CPU_INL_H_
namespace mshadow {
template<>
inline void InitTensorEngine<cpu>(int dev_id) {
}
template<>
inline void ShutdownTensorEngine<cpu>(void) {
}
template<>
inline void SetDevice<cpu>(int devid) {
}
template<>
inline Stream<cpu> *NewStream<cpu>(bool create_blas_handle,
bool create_dnn_handle) {
return new Stream<cpu>();
}
template<>
inline void DeleteStream<cpu>(Stream<cpu> *stream) {
delete stream;
}
template<int ndim>
inline std::ostream &operator<<(std::ostream &os, const Shape<ndim> &shape) { // NOLINT(*)
os << '(';
for (int i = 0; i < ndim; ++i) {
if (i != 0) os << ',';
os << shape[i];
}
// python style tuple
if (ndim == 1) os << ',';
os << ')';
return os;
}
template<typename xpu>
inline void *AllocHost_(size_t size);
template<typename xpu>
inline void FreeHost_(void * dptr);
#ifdef __CUDACC__
template<>
inline void *AllocHost_<gpu>(size_t size) {
void *dptr;
MSHADOW_CUDA_CALL(cudaMallocHost(&dptr, size, cudaHostAllocPortable));
return dptr;
}
template<>
inline void FreeHost_<gpu>(void *dptr) {
MSHADOW_CUDA_CALL(cudaFreeHost(dptr));
}
#endif
template<>
inline void *AllocHost_<cpu>(size_t size) {
size_t pitch;
return packet::AlignedMallocPitch(&pitch, size, 1);
}
template<>
inline void FreeHost_<cpu>(void *dptr) {
packet::AlignedFree(dptr);
}
template<typename xpu, int dim, typename DType>
inline void AllocHost(Tensor<cpu, dim, DType> *obj) {
obj->stride_ = obj->size(dim - 1);
CHECK_EQ(obj->CheckContiguous(), true) << "AllocHost";
void *dptr = AllocHost_<xpu>(obj->MSize() * sizeof(DType));
obj->dptr_ = reinterpret_cast<DType*>(dptr);
}
template<typename xpu, int dim, typename DType>
inline void FreeHost(Tensor<cpu, dim, DType> *obj) {
if (obj->dptr_ == NULL) {
LOG(FATAL) << "FreeHost:: double free";
}
FreeHost_<xpu>(obj->dptr_);
obj->dptr_ = NULL;
}
template<int dim, typename DType>
inline void AllocSpace(Tensor<cpu, dim, DType> *obj, bool pad) {
size_t pitch;
void *dptr;
if (pad) {
dptr = packet::AlignedMallocPitch
(&pitch, obj->size(dim - 1) * sizeof(DType), obj->shape_.FlatTo2D()[0]);
obj->stride_ = static_cast<index_t>(pitch / sizeof(DType));
} else {
obj->stride_ = obj->size(dim - 1);
dptr = packet::AlignedMallocPitch
(&pitch, obj->shape_.Size() * sizeof(DType), 1);
}
obj->dptr_ = reinterpret_cast<DType*>(dptr);
}
template<typename Device, typename DType, int dim>
inline Tensor<Device, dim, DType>
NewTensor(const Shape<dim> &shape, DType initv, bool pad, Stream<Device> *stream_) {
Tensor<Device, dim, DType> obj(shape);
obj.stream_ = stream_;
AllocSpace(&obj, pad);
MapExp<sv::saveto>(&obj, expr::ScalarExp<DType>(initv));
return obj;
}
template<int dim, typename DType>
inline void FreeSpace(Tensor<cpu, dim, DType> *obj) {
packet::AlignedFree(obj->dptr_);
obj->dptr_ = NULL;
}
template<int dim, typename DType>
inline void Copy(Tensor<cpu, dim, DType> _dst,
const Tensor<cpu, dim, DType> &_src,
Stream<cpu> *stream) {
CHECK_EQ(_dst.shape_, _src.shape_)
<< "Copy:shape mismatch:" << _dst.shape_ << " vs " << _src.shape_;
if (_dst.CheckContiguous() && _src.CheckContiguous()) {
memcpy(_dst.dptr_, _src.dptr_, sizeof(DType) * _dst.shape_.Size());
} else {
Tensor<cpu, 2, DType> dst = _dst.FlatTo2D();
Tensor<cpu, 2, DType> src = _src.FlatTo2D();
for (index_t y = 0; y < dst.size(0); ++y) {
memcpy(dst[y].dptr_, src[y].dptr_, sizeof(DType) * dst.size(1));
}
}
}
template<typename Saver, typename R, int dim,
typename DType, typename E>
inline void MapPlan(TRValue<R, cpu, dim, DType> *dst,
const expr::Plan<E, DType> &plan) {
Shape<2> shape = expr::ShapeCheck<dim, R>::Check(dst->self()).FlatTo2D();
expr::Plan<R, DType> dplan = expr::MakePlan(dst->self());
#if (MSHADOW_USE_CUDA == 0)
#pragma omp parallel for
#endif
// temp remove openmp, as default setting throttles CPU
for (openmp_index_t y = 0; y < shape[0]; ++y) {
for (index_t x = 0; x < shape[1]; ++x) {
// trust your compiler! -_- they will optimize it
Saver::template Save<DType>(dplan.REval(y, x), plan.Eval(y, x));
}
}
}
// code to handle SSE optimization
template<bool pass_check, typename Saver,
typename R, int dim,
typename DType, typename E, int etype>
struct MapExpCPUEngine {
inline static void Map(TRValue<R, cpu, dim, DType> *dst,
const expr::Exp<E, DType, etype> &exp) {
MapPlan<Saver>(dst, MakePlan(exp.self()));
}
};
template<typename SV, int dim, typename DType, typename E, int etype>
struct MapExpCPUEngine<true, SV, Tensor<cpu, dim, DType>,
dim, DType, E, etype> {
inline static void Map(Tensor<cpu, dim, DType> *dst,
const expr::Exp<E, DType, etype> &exp) {
if (expr::PacketAlignCheck<dim, E, MSHADOW_DEFAULT_PACKET>::Check(exp.self()) &&
expr::PacketAlignCheck<dim, Tensor<cpu, dim, DType>, MSHADOW_DEFAULT_PACKET>::Check(*dst)) {
expr::MapPacketPlan<SV>(dst->self(),
expr::MakePacketPlan<MSHADOW_DEFAULT_PACKET>(exp.self()));
} else {
MapPlan<SV>(dst, MakePlan(exp.self()));
}
}
};
template<typename Saver, typename R, int dim,
typename DType, typename E, int etype>
inline void MapExp(TRValue<R, cpu, dim, DType> *dst,
const expr::Exp<E, DType, etype> &exp) {
expr::TypeCheckPass<expr::TypeCheck<cpu, dim, DType, E>::kMapPass>
::Error_All_Tensor_in_Exp_Must_Have_Same_Type();
Shape<dim> eshape = expr::ShapeCheck<dim, E>::Check(exp.self());
Shape<dim> dshape = expr::ShapeCheck<dim, R>::Check(dst->self());
CHECK(eshape[0] == 0 || eshape == dshape)
<< "Assignment: Shape of Tensors are not consistent with target, "
<< "eshape: " << eshape << " dshape:" << dshape;
MapExpCPUEngine<expr::PacketCheck<E, MSHADOW_DEFAULT_PACKET>::kPass,
Saver, R, dim, DType, E, etype>
::Map(dst->ptrself(), exp);
}
template<typename Saver, typename Reducer,
typename R, typename DType, typename E, int etype>
inline void MapReduceKeepLowest(TRValue<R, cpu, 1, DType> *dst,
const expr::Exp<E, DType, etype> &exp,
DType scale) {
expr::TypeCheckPass<expr::TypeCheck<cpu, 1, DType, E>::kRedPass>
::Error_TypeCheck_Not_Pass_For_Reduce_Exp();
Shape<2> eshape = expr::ShapeCheck<expr::ExpInfo<E>::kDim, E>
::Check(exp.self()).FlatTo2D();
Shape<1> dshape = expr::ShapeCheck<1, R>::Check(dst->self());
  CHECK_EQ(eshape[1], dshape[0]) << "MapReduceKeepLowest::reduction dimensions do not match";
CHECK_NE(eshape[0], 0) << "can not reduce over empty tensor";
// execution
expr::Plan<R, DType> dplan = MakePlan(dst->self());
expr::Plan<E, DType> splan = MakePlan(exp.self());
#if (MSHADOW_USE_CUDA == 0)
#pragma omp parallel for
#endif
for (openmp_index_t x = 0; x < eshape[1]; ++x) {
DType res = splan.Eval(0, x);
for (index_t y = 1; y < eshape[0]; ++y) {
Reducer::Reduce(res, splan.Eval(y, x));
}
Saver::template Save<DType>(dplan.REval(0, x), res * scale);
}
}
template<typename Saver, typename Reducer, int dimkeep,
typename R, typename DType, typename E, int etype>
inline void MapReduceKeepHighDim(TRValue<R, cpu, 1, DType> *dst,
const expr::Exp<E, DType, etype> &exp,
DType scale) {
expr::TypeCheckPass<expr::TypeCheck<cpu, dimkeep, DType, E>::kRedPass>
::Error_TypeCheck_Not_Pass_For_Reduce_Exp();
typedef Shape<expr::ExpInfo<E>::kDim> EShape;
EShape eshape = expr::ShapeCheck<expr::ExpInfo<E>::kDim, E>
::Check(exp.self());
Shape<1> dshape = expr::ShapeCheck<1, R>::Check(dst->self());
  CHECK_EQ(eshape[dimkeep], dshape[0])
    << "MapReduceKeepHighDim::reduction dimensions do not match";
  // use equivalent form
Shape<4> pshape = Shape4(eshape.ProdShape(0, dimkeep),
eshape[dimkeep],
eshape.ProdShape(dimkeep + 1, EShape::kSubdim),
eshape[EShape::kSubdim]);
// execution
expr::Plan<R, DType> dplan = MakePlan(dst->self());
expr::Plan<E, DType> splan = MakePlan(exp.self());
#if (MSHADOW_USE_CUDA == 0)
#pragma omp parallel for
#endif
for (openmp_index_t c = 0; c < pshape[1]; ++c) {
DType res; Reducer::SetInitValue(res);
for (index_t n = 0; n < pshape[0]; ++n) {
DType tres; Reducer::SetInitValue(tres);
for (index_t y = 0; y < pshape[2]; ++y) {
for (index_t x = 0; x < pshape[3]; ++x) {
Reducer::Reduce(tres,
splan.Eval((n * pshape[1] + c) * pshape[2] + y, x));
}
}
Reducer::Reduce(res, tres);
}
Saver::template Save<DType>(dplan.REval(0, c), DType(res * scale));
}
}
template<typename DType>
inline void Softmax(Tensor<cpu, 1, DType> dst,
const Tensor<cpu, 1, DType> &energy) {
DType mmax = energy[0];
for (index_t x = 1; x < dst.size(0); ++x) {
if (mmax < energy[x]) mmax = energy[x];
}
DType sum = DType(0.0f);
for (index_t x = 0; x < dst.size(0); ++x) {
dst[x] = std::exp(energy[x] - mmax);
sum += dst[x];
}
for (index_t x = 0; x < dst.size(0); ++x) {
dst[x] /= sum;
}
}
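// Worked sketch for the 1-D case (illustrative values): for energy = {1, 2, 3} the code first
// subtracts mmax = 3, giving exp(-2), exp(-1), exp(0) = {0.135, 0.368, 1.0}; dividing by their
// sum (1.503) yields dst = {0.090, 0.245, 0.665}. Subtracting the maximum does not change the
// result but avoids overflow in std::exp for large energies.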
template<typename DType>
inline void SoftmaxGrad(Tensor<cpu, 2, DType> dst,
const Tensor<cpu, 2, DType> &src,
const Tensor<cpu, 1, DType> &label) {
#pragma omp parallel for
for (openmp_index_t y = 0; y < dst.size(0); ++y) {
const index_t k = static_cast<int>(label[y]);
for (index_t x = 0; x < dst.size(1); ++x) {
if (x == k) {
dst[y][k] = src[y][k] - 1.0f;
} else {
dst[y][x] = src[y][x];
}
}
}
}
template<typename DType>
inline void SoftmaxGrad(Tensor<cpu, 2, DType> dst,
const Tensor<cpu, 2, DType> &src,
const Tensor<cpu, 1, DType> &label,
const DType &ignore_label) {
#pragma omp parallel for
for (openmp_index_t y = 0; y < dst.size(0); ++y) {
const index_t k = static_cast<int>(label[y]);
for (index_t x = 0; x < dst.size(1); ++x) {
if (static_cast<int>(ignore_label) == k) {
dst[y][x] = 0.0f;
} else {
if (x == k) {
dst[y][k] = src[y][k] - 1.0f;
} else {
dst[y][x] = src[y][x];
}
}
}
}
}
template<typename DType>
inline void SoftmaxGrad(Tensor<cpu, 3, DType> dst,
const Tensor<cpu, 3, DType> &src,
const Tensor<cpu, 2, DType> &label) {
#pragma omp parallel for
for (openmp_index_t n = 0; n < dst.size(2); ++n) {
for (index_t y = 0; y < dst.size(0); ++y) {
const index_t k = static_cast<int>(label[y][n]);
for (index_t x = 0; x < dst.size(1); ++x) {
if (x == k) {
dst[y][k][n] = src[y][k][n] - 1.0f;
} else {
dst[y][x][n] = src[y][x][n];
}
}
}
}
}
template<typename DType>
inline void SoftmaxGrad(Tensor<cpu, 3, DType> dst,
const Tensor<cpu, 3, DType> &src,
const Tensor<cpu, 2, DType> &label,
const DType &ignore_label) {
#pragma omp parallel for
for (openmp_index_t n = 0; n < dst.size(2); ++n) {
for (index_t y = 0; y < dst.size(0); ++y) {
const index_t k = static_cast<int>(label[y][n]);
if (k == static_cast<int>(ignore_label)) {
for (index_t x = 0; x < dst.size(1); ++x) {
dst[y][x][n] = DType(0.0f);
}
} else {
for (index_t x = 0; x < dst.size(1); ++x) {
if (x == k) {
dst[y][k][n] = src[y][k][n] - 1.0f;
} else {
dst[y][x][n] = src[y][x][n];
}
}
}
}
}
}
template<typename DType>
inline void Softmax(Tensor<cpu, 2, DType> dst,
const Tensor<cpu, 2, DType> &energy) {
CHECK_EQ(dst.shape_, energy.shape_) << "Softmax: shape mismatch";
#pragma omp parallel for
for (openmp_index_t y = 0; y < dst.size(0); ++y) {
Softmax(dst[y], energy[y]);
}
}
template<typename DType>
inline void Softmax(Tensor<cpu, 3, DType> dst,
const Tensor<cpu, 3, DType> &energy) {
CHECK_EQ(dst.shape_, energy.shape_) << "Softmax: shape mismatch";
#pragma omp parallel for
for (openmp_index_t y = 0; y < dst.size(0); ++y) {
for (index_t n = 0; n < dst.size(2); ++n) {
DType mmax = energy[y][0][n];
for (index_t x = 1; x < dst.size(1); ++x) {
if (mmax < energy[y][x][n]) mmax = energy[y][x][n];
}
DType sum = DType(0.0f);
for (index_t x = 0; x < dst.size(1); ++x) {
dst[y][x][n] = std::exp(energy[y][x][n] - mmax);
sum += dst[y][x][n];
}
for (index_t x = 0; x < dst.size(1); ++x) {
dst[y][x][n] /= sum;
}
}
}
}
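/*
 * Minimal usage sketch for the CPU Softmax overloads above (added for
 * illustration; tensor names are hypothetical). Each row of `energy` is
 * normalized independently:
 *
 *   TensorContainer<cpu, 2, float> energy(Shape2(4, 10));
 *   TensorContainer<cpu, 2, float> prob(Shape2(4, 10));
 *   // ... fill energy ...
 *   Softmax(prob, energy);  // prob[y][x] = exp(energy[y][x] - max_x) / sum_x
 */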
template<typename IndexType, typename DType>
inline void AddTakeGrad(Tensor<cpu, 2, DType> dst,
const Tensor<cpu, 1, IndexType>& index,
const Tensor<cpu, 2, DType> &src) {
for (index_t y = 0; y < index.size(0); ++y) {
dst[index[y]] += src[y];
}
}
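/*
 * Illustration (added for clarity): AddTakeGrad scatters rows of src into dst
 * by index and accumulates duplicates, i.e. the usual backward pass of an
 * embedding/take operation. With hypothetical values index = [2, 0, 2]:
 *
 *   dst[2] += src[0];  dst[0] += src[1];  dst[2] += src[2];
 *
 * AddTakeGradLargeBatch below computes the same result, but expects the
 * indices pre-sorted (`sorted`) together with the permutation (`index`)
 * that produced the sorted order.
 */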
template<typename IndexType, typename DType>
inline void AddTakeGradLargeBatch(Tensor<cpu, 2, DType> dst,
const Tensor<cpu, 1, IndexType>& sorted,
const Tensor<cpu, 1, IndexType>& index,
const Tensor<cpu, 2, DType> &src) {
for (index_t y = 0; y < sorted.size(0); ++y) {
dst[sorted[y]] += src[index[y]];
}
}
template<typename IndexType, typename DType>
inline void IndexFill(Tensor<cpu, 2, DType> dst,
const Tensor<cpu, 1, IndexType>& index,
const Tensor<cpu, 2, DType> &src) {
for (index_t y = 0; y < index.size(0); ++y) {
for (index_t j = 0; j < src.size(1); j++) {
dst[index[y]][j] = src[y][j];
}
}
}
template<typename KDType, typename VDType>
inline void SortByKey(Tensor<cpu, 1, KDType> keys, Tensor<cpu, 1, VDType> values,
bool is_ascend) {
CHECK_EQ(keys.CheckContiguous(), true);
CHECK_EQ(values.CheckContiguous(), true);
CHECK_EQ(keys.size(0), values.size(0))
<< "The sizes of key/value are not equal! keys_size: " << keys.size(0)
<< "values_size: " << values.size(0);
std::vector<size_t> idx(keys.size(0));
std::vector<KDType> keys_vec(keys.size(0));
std::vector<VDType> values_vec(values.size(0));
  for (index_t i = 0; i < keys.size(0); i++) {
idx[i] = i;
keys_vec[i] = keys[i];
values_vec[i] = values[i];
}
if (is_ascend) {
std::stable_sort(idx.begin(), idx.end(),
[&keys_vec](size_t i1, size_t i2)
{return keys_vec[i1] < keys_vec[i2]; });
} else {
std::stable_sort(idx.begin(), idx.end(),
[&keys_vec](size_t i1, size_t i2)
{return keys_vec[i1] > keys_vec[i2]; });
}
for (index_t i = 0; i < values.size(0); i++) {
keys[i] = keys_vec[idx[i]];
values[i] = values_vec[idx[i]];
}
}
template<typename Device, typename VDType, typename SDType>
inline void VectorizedSort(Tensor<Device, 1, VDType> values, Tensor<Device, 1, SDType> segments) {
// We can sort each segments using two stable sorts
SortByKey(values, segments, true);
SortByKey(segments, values, true);
}
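/*
 * Worked example for VectorizedSort (added for clarity; values are made up).
 * Values are sorted within each segment using two stable sorts:
 *
 *   values   = [3, 1, 2, 5, 4]         segments = [1, 0, 1, 0, 1]
 *   after SortByKey(values, segments): values = [1, 2, 3, 4, 5], segments = [0, 1, 1, 1, 0]
 *   after SortByKey(segments, values): segments = [0, 0, 1, 1, 1], values = [1, 5, 2, 3, 4]
 *
 * Segment 0 ends up holding {1, 5} and segment 1 holds {2, 3, 4}, each in
 * ascending order, because the second sort is stable.
 */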
// blas related
template<typename Device, typename DType>
inline void VectorDot(Tensor<Device, 1, DType> dst,
const Tensor<Device, 1, DType> &lhs,
const Tensor<Device, 1, DType> &rhs) {
CHECK_EQ(lhs.size(0), rhs.size(0))
<< "VectorDot: Shape mismatch";
CHECK_EQ(dst.size(0), 1)
<< "VectorDot: expect dst to be scalar";
expr::BLASEngine<Device, DType>::SetStream(lhs.stream_);
mshadow::expr::BLASEngine<Device, DType>::dot(
lhs.stream_, lhs.size(0), lhs.dptr_, 1, rhs.dptr_, 1, dst.dptr_);
}
template<bool transpose_left, bool transpose_right, typename Device, typename DType>
inline void BatchGEMM(Tensor<Device, 3, DType> dst,
const Tensor<Device, 3, DType> &lhs,
const Tensor<Device, 3, DType> &rhs,
DType alpha,
DType beta,
Tensor<Device, 1, DType*> workspace) {
index_t batch_size = dst.shape_[0];
expr::BLASEngine<Device, DType>::SetStream(dst.stream_);
Shape<3> sleft = transpose_left ? Shape3(lhs.shape_[0], lhs.shape_[2], lhs.shape_[1])
: lhs.shape_;
Shape<3> sright = transpose_right ? Shape3(rhs.shape_[0], rhs.shape_[2], rhs.shape_[1])
: rhs.shape_;
CHECK_EQ(dst.CheckContiguous(), true);
CHECK_EQ(lhs.CheckContiguous(), true);
CHECK_EQ(rhs.CheckContiguous(), true);
CHECK(sleft[0] == batch_size && sright[0] == batch_size)
<< "BatchGEMM: batchsize must be equal."
<< "dst: " << dst.shape_ << "\n"
<< "lhs: " << sleft << "\n"
<< "rhs: " << sright << "\n";
CHECK(dst.size(1) == sleft[1] && dst.size(2) == sright[2] && sleft[2] == sright[1])
<< "BatchGEMM: matrix shape mismatch"
<< "dst: " << dst.shape_ << "\n"
<< "lhs: " << sleft << "\n"
<< "rhs: " << sright << "\n";
CHECK(workspace.size(0) >= 3 * batch_size)
    << "Workspace size must be at least " << 3 * batch_size;
CHECK_EQ(workspace.CheckContiguous(), true);
// use column major argument to compatible with most BLAS
expr::BLASEngine<Device, DType>::batched_gemm
(dst.stream_,
transpose_right, transpose_left,
transpose_right ? rhs.size(1) : rhs.size(2),
transpose_left ? lhs.size(2) : lhs.size(1),
transpose_right ? rhs.size(2) : rhs.size(1),
alpha,
rhs.dptr_, rhs.stride_,
lhs.dptr_, lhs.stride_,
beta,
dst.dptr_, dst.stride_, batch_size,
workspace.dptr_);
}
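/*
 * Usage sketch for BatchGEMM (added for illustration; names and sizes are
 * hypothetical). It computes dst[i] = alpha * lhs[i] * rhs[i] + beta * dst[i]
 * for every matrix i in the batch; `workspace` must hold at least
 * 3 * batch_size pointers for the batched BLAS call:
 *
 *   Tensor<cpu, 3, float> lhs, rhs, dst;   // shapes (B, M, K), (B, K, N), (B, M, N)
 *   Tensor<cpu, 1, float*> workspace;      // size(0) >= 3 * B
 *   // ... allocate and fill ...
 *   BatchGEMM<false, false>(dst, lhs, rhs, 1.0f, 0.0f, workspace);
 */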
} // namespace mshadow
#endif // MSHADOW_TENSOR_CPU_INL_H_
//===== EXPANDED: ../mshadow/mshadow/tensor_cpu-inl.h =====
//===== EXPANDING: ../mshadow/mshadow/tensor_gpu-inl.h =====
/*!
* Copyright (c) 2014 by Contributors
* \file tensor_gpu-inl.h
* \brief implementation of GPU host code
* \author Bing Xu, Tianqi Chen
*/
#ifndef MSHADOW_TENSOR_GPU_INL_H_
#define MSHADOW_TENSOR_GPU_INL_H_
namespace mshadow {
#if MSHADOW_USE_CUDA
template<>
inline void InitTensorEngine<gpu>(int dev_id) {
cudaDeviceProp prop;
int device_id = 0;
int device_count = 0;
cudaGetDeviceCount(&device_count);
CHECK_GT(device_count, 0) << "Cannot find CUDA device. Please check CUDA-Configuration";
if (dev_id < 0) {
device_id = 0;
} else {
device_id = dev_id;
}
CHECK_LT(device_id, device_count) << "Incorrect Device ID";
MSHADOW_CUDA_CALL(cudaSetDevice(device_id));
MSHADOW_CUDA_CALL(cudaGetDeviceProperties(&prop, device_id));
}
template<>
inline void ShutdownTensorEngine<gpu>(void) {
}
template<>
inline void SetDevice<gpu>(int devid) {
MSHADOW_CUDA_CALL(cudaSetDevice(devid));
}
template<int dim, typename DType>
inline void AllocSpace(Tensor<gpu, dim, DType> *obj, bool pad) {
size_t pitch;
// common choice for cuda mem align unit is 32
if (pad && obj->size(dim - 1) >= MSHADOW_MIN_PAD_RATIO * 32) {
MSHADOW_CUDA_CALL(cudaMallocPitch(reinterpret_cast<void**>(&(obj->dptr_)), &pitch,
obj->size(dim - 1) * sizeof(DType),
obj->shape_.FlatTo2D()[0]));
obj->stride_ = static_cast<index_t>(pitch / sizeof(DType));
} else {
obj->stride_ = obj->size(dim - 1);
MSHADOW_CUDA_CALL(cudaMallocPitch(reinterpret_cast<void**>(&(obj->dptr_)), &pitch,
obj->shape_.Size() * sizeof(DType), 1));
}
}
template<int dim, typename DType>
inline void FreeSpace(Tensor<gpu, dim, DType> *obj) {
MSHADOW_CUDA_CALL(cudaFree(obj->dptr_));
obj->dptr_ = NULL;
}
template<typename A, typename B, int dim, typename DType>
inline void Copy(Tensor<A, dim, DType> _dst,
Tensor<B, dim, DType> _src,
cudaMemcpyKind kind,
Stream<gpu> *stream) {
CHECK_EQ(_dst.shape_, _src.shape_) << "Copy:shape mismatch";
Tensor<A, 2, DType> dst = _dst.FlatTo2D();
Tensor<B, 2, DType> src = _src.FlatTo2D();
MSHADOW_CUDA_CALL(cudaMemcpy2DAsync(dst.dptr_, dst.stride_ * sizeof(DType),
src.dptr_, src.stride_ * sizeof(DType),
dst.size(1) * sizeof(DType),
dst.size(0), kind,
Stream<gpu>::GetStream(stream)));
// use synchronize call behavior for zero stream
if (stream == NULL) {
MSHADOW_CUDA_CALL(cudaStreamSynchronize(0));
}
}
template<int dim, typename DType>
inline void Copy(Tensor<cpu, dim, DType> dst,
const Tensor<gpu, dim, DType> &src,
Stream<gpu> *stream) {
Copy(dst, src, cudaMemcpyDeviceToHost, stream);
}
template<int dim, typename DType>
inline void Copy(Tensor<gpu, dim, DType> dst,
const Tensor<gpu, dim, DType> &src,
Stream<gpu> *stream) {
Copy(dst, src, cudaMemcpyDeviceToDevice, stream);
}
template<int dim, typename DType>
inline void Copy(Tensor<gpu, dim, DType> dst,
const Tensor<cpu, dim, DType> &src,
Stream<gpu> *stream) {
Copy(dst, src, cudaMemcpyHostToDevice, stream);
}
#endif // MSHADOW_USE_CUDA
} // namespace mshadow
// the following part is included only if compiler is nvcc
#ifdef __CUDACC__
namespace mshadow {
template<typename Saver, typename R, int dim,
typename DType, typename E, int etype>
inline void MapExp(TRValue<R, gpu, dim, DType> *dst,
const expr::Exp<E, DType, etype> &exp) {
expr::TypeCheckPass<expr::TypeCheck<gpu, dim, DType, E>::kMapPass>
::Error_All_Tensor_in_Exp_Must_Have_Same_Type();
Shape<dim> eshape = expr::ShapeCheck<dim, E>::Check(exp.self());
Shape<dim> dshape = expr::ShapeCheck<dim, R>::Check(dst->self());
CHECK(eshape[0] == 0 || eshape == dshape)
<< "Assignment: Shape of Tensors are not consistent with target, "
<< "eshape: " << eshape << " dshape:" << dshape;
cuda::MapPlan<Saver>(MakePlan(dst->self()),
MakePlan(exp.self()),
dshape.FlatTo2D(),
Stream<gpu>::GetStream(expr::StreamInfo<gpu, R>::Get(dst->self())));
}
template<typename Saver, typename Reducer,
typename R, typename DType, typename E, int etype>
inline void MapReduceKeepLowest(TRValue<R, gpu, 1, DType> *dst,
const expr::Exp<E, DType, etype> &exp,
DType scale) {
expr::TypeCheckPass<expr::TypeCheck<gpu, 1, DType, E>::kRedPass>
::Error_TypeCheck_Not_Pass_For_Reduce_Exp();
Shape<2> eshape = expr::ShapeCheck<expr::ExpInfo<E>::kDim, E>
::Check(exp.self()).FlatTo2D();
Shape<1> dshape = expr::ShapeCheck<1, R>::Check(dst->self());
CHECK_EQ(eshape[1], dshape[0]) << "MapReduceKeepLowest::reduction dimension do not match";
CHECK_NE(eshape[0], 0) << "can not reduce over empty tensor";
cuda::MapReduceKeepLowest<Saver, Reducer>
(MakePlan(dst->self()), MakePlan(exp.self()), scale, eshape,
Stream<gpu>::GetStream(expr::StreamInfo<gpu, R>::Get(dst->self())));
}
template<typename Saver, typename Reducer, int dimkeep,
typename R, typename DType, typename E, int etype>
inline void MapReduceKeepHighDim(TRValue<R, gpu, 1, DType> *dst,
const expr::Exp<E, DType, etype> &exp,
DType scale) {
expr::TypeCheckPass<expr::TypeCheck<gpu, dimkeep, DType, E>::kRedPass>
::Error_TypeCheck_Not_Pass_For_Reduce_Exp();
typedef Shape<expr::ExpInfo<E>::kDim> EShape;
EShape eshape = expr::ShapeCheck<expr::ExpInfo<E>::kDim, E>
::Check(exp.self());
Shape<1> dshape = expr::ShapeCheck<1, R>::Check(dst->self());
CHECK_EQ(eshape[dimkeep], dshape[0]) << "MapReduceKeepHighDim::reduction dimension do not match";
  // use equivalent form
Shape<4> pshape = Shape4(eshape.ProdShape(0, dimkeep),
eshape[dimkeep],
eshape.ProdShape(dimkeep + 1, EShape::kSubdim),
eshape[EShape::kSubdim]);
  // call the equivalent MapReduceKeepDim1 kernel
cuda::MapReduceKeepDim1<Saver, Reducer>
(MakePlan(dst->self()), MakePlan(exp.self()), scale, pshape,
Stream<gpu>::GetStream(expr::StreamInfo<gpu, R>::Get(dst->self())));
}
template<typename DType>
inline void Softmax(Tensor<gpu, 2, DType> dst,
const Tensor<gpu, 2, DType>& src) {
cuda::Softmax(dst, src);
}
template<typename DType>
inline void Softmax(Tensor<gpu, 3, DType> dst,
const Tensor<gpu, 3, DType>& src) {
cuda::Softmax(dst, src);
}
template<typename DType>
inline void SoftmaxGrad(Tensor<gpu, 2, DType> dst,
const Tensor<gpu, 2, DType> &src,
const Tensor<gpu, 1, DType> &label) {
cuda::SoftmaxGrad(dst, src, label);
}
template<typename DType>
inline void SoftmaxGrad(Tensor<gpu, 2, DType> dst,
const Tensor<gpu, 2, DType> &src,
const Tensor<gpu, 1, DType> &label,
const DType &ignore_label) {
cuda::SoftmaxGrad(dst, src, label, ignore_label);
}
template<typename DType>
inline void SoftmaxGrad(Tensor<gpu, 3, DType> dst,
const Tensor<gpu, 3, DType> &src,
const Tensor<gpu, 2, DType> &label) {
cuda::SoftmaxGrad(dst, src, label);
}
template<typename DType>
inline void SoftmaxGrad(Tensor<gpu, 3, DType> dst,
const Tensor<gpu, 3, DType> &src,
const Tensor<gpu, 2, DType> &label,
const DType &ignore_label) {
cuda::SoftmaxGrad(dst, src, label, ignore_label);
}
template<typename IndexType, typename DType>
inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
const Tensor<gpu, 1, IndexType>& index,
const Tensor<gpu, 2, DType> &src) {
cuda::AddTakeGrad(dst, index, src);
}
template<typename IndexType, typename DType>
inline void AddTakeGradLargeBatch(Tensor<gpu, 2, DType> dst,
const Tensor<gpu, 1, IndexType>& sorted,
const Tensor<gpu, 1, IndexType>& index,
const Tensor<gpu, 2, DType> &src) {
cuda::AddTakeGradLargeBatch(dst, sorted, index, src);
}
template<typename KDType, typename VDType>
inline void SortByKey(Tensor<gpu, 1, KDType> keys, Tensor<gpu, 1, VDType> values,
bool is_ascend) {
cuda::SortByKey(keys, values, is_ascend);
}
template<typename IndexType, typename DType>
inline void IndexFill(Tensor<gpu, 2, DType> dst,
const Tensor<gpu, 1, IndexType>& index,
const Tensor<gpu, 2, DType> &src) {
cuda::IndexFill(dst, index, src);
}
} // namespace mshadow
#endif // __CUDACC__
#endif // MSHADOW_TENSOR_GPU_INL_H_
//===== EXPANDED: ../mshadow/mshadow/tensor_gpu-inl.h =====
//===== EXPANDING: ../mshadow/mshadow/io.h =====
/*!
* Copyright (c) 2014 by Contributors
* \file io.h
* \brief definitions of I/O functions for mshadow tensor
* \author Tianqi Chen
*/
#ifndef MSHADOW_IO_H_
#define MSHADOW_IO_H_
namespace mshadow {
namespace utils {
/*!
* \brief interface of stream I/O, used to serialize data,
 * mshadow is not restricted to this interface in SaveBinary/LoadBinary;
 * mshadow accepts any class that implements Read and Write
*/
class IStream {
public:
/*!
* \brief read data from stream
* \param ptr pointer to memory buffer
* \param size size of block
   * \return usually the size of data read
*/
virtual size_t Read(void *ptr, size_t size) = 0;
/*!
* \brief write data to stream
* \param ptr pointer to memory buffer
* \param size size of block
*/
virtual void Write(const void *ptr, size_t size) = 0;
/*! \brief virtual destructor */
virtual ~IStream(void) {}
};
} // namespace utils
/*!
* \brief CPU/GPU: save a tensor by binary format, for GPU version, a temp Tensor<cpu,dim> storage will be allocated
* \param fo output binary stream
 * \param src source tensor
* \tparam dim dimension of tensor
* \tparam DType type of element in tensor
* \tparam TStream type of stream, need to support Read, Write, one example is utils::IStream.
*/
template<int dim, typename DType, typename TStream>
inline void SaveBinary(TStream &fo, const Tensor<cpu, dim, DType> &src); // NOLINT(*)
/*!
* \brief CPU/GPU: save a tensor by binary format, for GPU version, a temp Tensor<cpu,dim> storage will be allocated
* \param fo output binary stream
 * \param src source tensor
* \tparam dim dimension of tensor
* \tparam DType type of element in tensor
* \tparam TStream type of stream, need to support Read, Write, one example is utils::IStream.
*/
template<int dim, typename DType, typename TStream>
inline void SaveBinary(TStream &fo, const Tensor<gpu, dim, DType> &src); // NOLINT(*)
/*!
* \brief CPU/GPU: load a tensor by binary format, for GPU version, a temp Tensor<cpu,dim> storage will be allocated
 * if pre_alloc is true, then space in dst is preallocated, and it must have the same shape as the tensor loaded
 * if pre_alloc is false, then dst originally does not have space allocated, LoadBinary will allocate space for dst
 * \param fi input binary stream
 * \param dst destination tensor
* \param pre_alloc whether space is pre-allocated, if false, space allocation will happen
* \tparam dim dimension of tensor
* \tparam DType type of element in tensor
* \tparam TStream type of stream, need to support Read, Write, one example is utils::IStream.
*/
template<int dim, typename DType, typename TStream>
inline void LoadBinary(TStream &fi, // NOLINT(*)
Tensor<cpu, dim, DType> *dst, bool pre_alloc);
/*!
* \brief CPU/GPU: load a tensor by binary format, for GPU version, a temp Tensor<cpu,dim> storage will be allocated
 * if pre_alloc is true, then space in dst is preallocated, and it must have the same shape as the tensor loaded
 * if pre_alloc is false, then dst originally does not have space allocated, LoadBinary will allocate space for dst
 * \param fi input binary stream
 * \param dst destination tensor
* \param pre_alloc whether space is pre-allocated, if false, space allocation will happen
* \tparam dim dimension of tensor
* \tparam DType type of element in tensor
* \tparam TStream type of stream, need to support Read, Write, one example is utils::IStream.
*/
template<int dim, typename DType, typename TStream>
inline void LoadBinary(TStream &fi, // NOLINT(*)
Tensor<gpu, dim, DType> *dst, bool pre_alloc);
// implementations
template<int dim, typename DType, typename TStream>
inline void SaveBinary(TStream &fo, const Tensor<cpu, dim, DType> &src_) { // NOLINT(*)
fo.Write(&src_.shape_, sizeof(src_.shape_));
Tensor<cpu, 2, DType> src = src_.FlatTo2D();
for (index_t i = 0; i < src.size(0); ++i) {
fo.Write(src[i].dptr_, sizeof(DType) * src.size(1));
}
}
template<int dim, typename DType, typename TStream>
inline void SaveBinary(TStream &fo, const Tensor<gpu, dim, DType> &src) { // NOLINT(*)
// copy to CPU, then save
Tensor<cpu, dim, DType> tmp(src.shape_);
AllocSpace(&tmp);
Stream<gpu> stream;
Copy(tmp, src, &stream);
SaveBinary(fo, tmp);
FreeSpace(&tmp);
}
template<int dim, typename DType, typename TStream>
inline void LoadBinary(TStream &fi, // NOLINT(*)
Tensor<cpu, dim, DType> *dst_, bool pre_alloc) {
Shape<dim> shape;
CHECK_NE(fi.Read(&shape, sizeof(shape)), 0) << "mshadow::LoadBinary";
if (pre_alloc) {
CHECK_EQ(shape, dst_->shape_) << "LoadBinary, shape do not match pre-allocated shape";
} else {
dst_->shape_ = shape; AllocSpace(dst_);
}
Tensor<cpu, 2, DType> dst = dst_->FlatTo2D();
if (dst.size(0) == 0) return;
for (index_t i = 0; i < dst.size(0); ++i) {
CHECK_NE(fi.Read(dst[i].dptr_, sizeof(DType) * dst.size(1)), 0) << "mshadow::LoadBinary";
}
}
template<int dim, typename DType, typename TStream>
inline void LoadBinary(TStream &fi, // NOLINT(*)
Tensor<gpu, dim, DType> *dst, bool pre_alloc) {
Tensor<cpu, dim, DType> tmp;
LoadBinary(fi, &tmp, false);
if (pre_alloc) {
    CHECK_EQ(tmp.shape_, dst->shape_) << "LoadBinary, shape do not match pre-allocated shape";
  } else {
    dst->shape_ = tmp.shape_; AllocSpace(dst);
}
Stream<gpu> stream;
Copy(*dst, tmp, &stream);
FreeSpace(&tmp);
}
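/*
 * Round-trip sketch for SaveBinary/LoadBinary (added for illustration).
 * Any stream type with Read/Write, such as utils::IStream, works; `fo` and
 * `fi` below are hypothetical stream objects:
 *
 *   Tensor<cpu, 2, float> mat(Shape2(3, 4));
 *   AllocSpace(&mat);
 *   SaveBinary(fo, mat);             // writes shape_, then the data row by row
 *   Tensor<cpu, 2, float> loaded;
 *   LoadBinary(fi, &loaded, false);  // pre_alloc == false: allocates space
 *   FreeSpace(&mat); FreeSpace(&loaded);
 */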
} // namespace mshadow
#endif // MSHADOW_IO_H_
//===== EXPANDED: ../mshadow/mshadow/io.h =====
//===== EXPANDING: ../mshadow/mshadow/tensor_container.h =====
/*!
* Copyright (c) 2014 by Contributors
* \file tensor_container.h
* \brief tensor container that does memory allocation and resize like STL
* \author Tianqi Chen
*/
#ifndef MSHADOW_TENSOR_CONTAINER_H_
#define MSHADOW_TENSOR_CONTAINER_H_
namespace mshadow {
/*!
* \brief tensor container that does memory allocation and resize like STL,
 * use it to avoid writing explicit FreeSpace calls in your classes.
 * Do not abuse it; efficiency comes from pre-allocation and avoiding re-allocation
*
* \tparam Device which device the tensor is on
* \tparam dimension dimension of the tensor
*/
template<typename Device, int dimension, typename DType = default_real_t>
class TensorContainer: public Tensor<Device, dimension, DType> {
public:
/*!
* \brief constructor
 * \param pad whether to use padding alignment in space allocation
*/
explicit TensorContainer(bool pad = MSHADOW_ALLOC_PAD) {
this->pad_ = pad;
this->dptr_ = data_.dptr_ = NULL;
this->shape_[0] = 0;
this->stride_ = 0;
this->data_.stride_ = 0;
this->data_.shape_[0] = 0;
}
/*!
* \brief constructor
 * \param shape initial shape
*/
explicit TensorContainer(const Shape<dimension> &shape) {
this->pad_ = MSHADOW_ALLOC_PAD;
data_.dptr_ = NULL;
this->AllocByShape(shape);
}
/*!
* \brief constructor
 * \param shape initial shape
 * \param initv initial value
*/
explicit TensorContainer(const Shape<dimension> &shape, DType initv) {
this->pad_ = MSHADOW_ALLOC_PAD;
data_.dptr_ = NULL;
this->AllocByShape(shape);
(*this) = initv;
}
/*!
* \brief copy constructor
* \param src source value
*/
TensorContainer
(const TensorContainer<Device, dimension, DType> &src)
: pad_(src.pad_) {
this->dptr_ = data_.dptr_ = NULL;
this->shape_[0] = 0;
this->stride_ = 0;
this->data_.stride_ = 0;
this->data_.shape_[0] = 0;
this->stream_ = src.stream_;
if (src.dptr_ != NULL) {
this->AllocByShape(src.shape_);
mshadow::Copy(*this, src, this->stream_);
}
}
~TensorContainer(void) {
this->Release();
}
/*!
* \brief resize the container to given shape, content is NOT preserved
* \param shape target shape
*/
inline void Resize(const Shape<dimension> &shape) {
Shape<2> s2 = shape.FlatTo2D();
if (s2.shape_[1] > data_.stride_ || s2.shape_[0] > data_.size(0)) {
this->AllocByShape(shape);
} else {
this->shape_ = shape;
if (this->pad_) {
this->stride_ = data_.stride_;
} else {
this->stride_ = s2.shape_[1];
}
}
}
/*!
* \brief resize the container to given shape, and initialize, content is NOT preserved
* \param shape target shape
* \param initv initialization value
*/
inline void Resize(const Shape<dimension> &shape, DType initv) {
this->Resize(shape);
(*this) = initv;
}
/*! \brief set whether padding is allowed in tensor */
inline void set_pad(bool pad) {
this->pad_ = pad;
}
/*!
* \brief save by binary format
* \param fo output binary stream
* \tparam TStream type of stream, need to support Read, Write, one example is utils::IStream.
*/
template<typename TStream>
inline void SaveBinary(TStream &fo) const { // NOLINT(*)
mshadow::SaveBinary(fo, *this);
}
/*!
* \brief load by binary format, a temp Tensor<cpu,dim> storage will be allocated
* \param fi input binary stream
* \tparam TStream type of stream, need to support Read, Write, one example is utils::IStream.
*/
template<typename TStream>
inline void LoadBinary(TStream &fi) { // NOLINT(*)
Tensor<cpu, dimension, DType> tmp;
mshadow::LoadBinary(fi, &tmp, false);
this->Resize(tmp.shape_);
Stream<Device> stream;
Copy(*this, tmp, &stream);
mshadow::FreeSpace(&tmp);
}
/*!
* \brief assign operator from TensorContainer
* \param src source value
* \return reference of self
*/
inline TensorContainer &operator=
(const TensorContainer<Device, dimension, DType> &src) {
this->pad_ = src.pad_;
this->stream_ = src.stream_;
if (src.dptr_ != NULL) {
this->Resize(src.shape_);
mshadow::Copy(*this, src, this->stream_);
}
return *this;
}
/*!\brief functions to fit expression template */
inline Tensor<Device, dimension, DType> &operator=(DType s) {
return this->__assign(s);
}
/*!\brief functions to fit expression template */
template<typename E>
inline Tensor<Device, dimension, DType> &
operator=(const expr::Exp<E, DType, expr::type::kMapper> &exp) {
return this->__assign(exp);
}
/*!\brief functions to fit expression template */
template<typename E>
inline Tensor<Device, dimension, DType> &
operator=(const expr::Exp<E, DType, expr::type::kChainer> &exp) {
return this->__assign(exp);
}
/*!\brief functions to fit expression template */
template<typename E>
inline Tensor<Device, dimension, DType> &
operator=(const expr::Exp<E, DType, expr::type::kComplex> &exp) {
return this->__assign(exp);
}
/*!
 * \brief Release the allocated space.
 * The TensorContainer is still functional,
 * but will re-allocate space when Resize is called.
*/
inline void Release(void) {
if (data_.dptr_ != NULL) {
this->shape_[0] = 0;
this->stride_ = 0;
this->data_.stride_ = 0;
this->data_.shape_[0] = 0;
try {
mshadow::FreeSpace(&data_);
} catch (const dmlc::Error &e) {
this->dptr_ = data_.dptr_ = NULL;
throw e;
}
this->dptr_ = data_.dptr_ = NULL;
}
}
private:
/*! \brief whether we do padding in the space */
bool pad_;
  /*! \brief backing storage; the shape of data_ reflects the currently allocated space */
Tensor<Device, 2, DType> data_;
inline void AllocByShape(const Shape<dimension>& shape) {
if (data_.dptr_ != NULL) this->Release();
data_.shape_ = shape.FlatTo2D();
mshadow::AllocSpace(&data_, pad_);
this->dptr_ = data_.dptr_;
this->shape_ = shape;
if (this->pad_) {
this->stride_ = data_.stride_;
} else {
this->stride_ = data_.size(1);
}
}
};
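/*
 * Minimal usage sketch for TensorContainer (added for illustration): it owns
 * its memory, so no explicit AllocSpace/FreeSpace is needed, and Resize only
 * re-allocates when the new shape does not fit in the current allocation.
 *
 *   TensorContainer<cpu, 2, float> buf(Shape2(2, 3), 0.0f);  // allocated and zeroed
 *   buf.Resize(Shape2(2, 2));         // fits: reuses the existing allocation
 *   buf.Resize(Shape2(8, 8), 1.0f);   // grows: re-allocates, then fills with 1
 */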
} // namespace mshadow
#endif // MSHADOW_TENSOR_CONTAINER_H_
//===== EXPANDED: ../mshadow/mshadow/tensor_container.h =====
//===== EXPANDING: ../mshadow/mshadow/tensor_blob.h =====
/*!
* Copyright (c) 2014 by Contributors
* \file tensor_blob.h
* \brief TBlob class that holds common representation of
 *  an arbitrary dimension tensor, and can be transformed
 *  into a normal fixed-dimension tensor
* \author Tianqi Chen
*/
#ifndef MSHADOW_TENSOR_BLOB_H_
#define MSHADOW_TENSOR_BLOB_H_
namespace mshadow {
/*!
* \brief dynamic shape class that can hold shape
* of arbitrary dimension
*/
struct TShape {
public:
/*! \brief constructor */
TShape()
: ndim_(0),
num_heap_allocated_(0),
data_heap_(NULL) {}
/*!
* \brief construct an "all-one" TShape with given dimension
* \param ndim the number of dimension of the shape
*/
explicit TShape(index_t ndim)
: ndim_(ndim) {
if (ndim_ <= kStackCache) {
data_heap_ = NULL;
num_heap_allocated_ = 0;
std::fill_n(data_stack_, ndim_, 1);
} else {
data_heap_ = new index_t[ndim_];
num_heap_allocated_ = ndim_;
std::fill_n(data_heap_, ndim_, 1);
}
}
/*!
* \brief constructor from TShape
* \param s the source shape
*/
TShape(const TShape &s)
: ndim_(s.ndim_) {
if (ndim_ <= kStackCache) {
data_heap_ = NULL;
num_heap_allocated_ = 0;
std::copy(s.data_stack_, s.data_stack_ + ndim_, data_stack_);
} else {
data_heap_ = new index_t[ndim_];
num_heap_allocated_ = ndim_;
std::copy(s.data_heap_, s.data_heap_ + ndim_, data_heap_);
}
}
/*!
* \brief construct the TShape from content of iterator
* \param begin the beginning of iterator
* \param end end the end of the iterator
* \tparam RandomAccessIterator iterator type
*/
template<typename RandomAccessIterator>
TShape(RandomAccessIterator begin,
RandomAccessIterator end)
: ndim_(0),
num_heap_allocated_(0),
data_heap_(NULL) {
this->CopyFrom(begin, end);
}
#if MSHADOW_IN_CXX11
/*!
* \brief move constructor from TShape
* \param s the source shape
*/
TShape(TShape &&s)
: ndim_(s.ndim_),
num_heap_allocated_(s.num_heap_allocated_),
data_heap_(s.data_heap_) {
if (ndim_ <= kStackCache) {
std::copy(s.data_stack_, s.data_stack_ + ndim_, data_stack_);
}
// remove data heap space from s
s.data_heap_ = NULL;
}
/*!
* \brief move constructor from Shape
* \param s the source shape
*/
template<int dim>
TShape(Shape<dim> &&s) // NOLINT(*)
: ndim_(0),
num_heap_allocated_(0),
data_heap_(NULL) {
this->CopyFrom(s.shape_, s.shape_ + dim);
}
#endif
/*! \brief destructor */
~TShape() {
// data_heap_ can be NULL
delete [] data_heap_;
}
/*!
 * \brief copy shape from content between two iterators
* \param begin the beginning of iterator
* \param end the end of the iterator
* \tparam RandomAccessIterator iterator type
*/
template<typename RandomAccessIterator>
inline void CopyFrom(RandomAccessIterator begin,
RandomAccessIterator end) {
this->SetDim(end - begin);
std::copy(begin, end, data());
}
/*!
* \brief assignment from shape
* \param shape source shape
* \return reference of self
*/
inline TShape &operator=(const TShape &shape) {
this->SetDim(shape.ndim_);
const index_t *src = shape.data();
std::copy(src, src + ndim_, data());
return *this;
}
/*!
* \brief assignment from vector
* \param shape source shape
* \return reference of self
*/
inline TShape &operator=(const std::vector<index_t> &shape) {
this->CopyFrom(shape.begin(), shape.end());
return *this;
}
/*!
* \brief assignment from shape
* \param shape source shape
* \tparam dim shape dimension
* \return reference of self
*/
template<int dim>
inline TShape &operator=(const Shape<dim> &shape) {
this->SetDim(dim);
index_t *d = dim <= kStackCache ? data_stack_ : data_heap_;
for (int i = 0; i < dim; ++i) {
d[i] = shape[i];
}
return *this;
}
/*! \return the data content of the shape */
inline const index_t *data() const {
return ndim_ <= kStackCache ? data_stack_ : data_heap_;
}
/*! \return the data content of the shape */
inline index_t *data() {
return ndim_ <= kStackCache ? data_stack_ : data_heap_;
}
/*! \brief return number of dimension of the tensor inside */
inline index_t ndim(void) const {
return ndim_;
}
/*!
* \brief get corresponding index
* \param i dimension index
* \return the corresponding dimension size
*/
inline index_t &operator[](index_t i) {
return data()[i];
}
/*!
* \brief get corresponding index
* \param i dimension index
* \return the corresponding dimension size
*/
inline const index_t &operator[](index_t i) const {
return data()[i];
}
/*! \brief total number of elements in the tensor */
inline size_t Size(void) const {
size_t size = 1;
const index_t *d = this->data();
for (index_t i = 0; i < ndim_; ++i) {
size *= d[i];
}
return size;
}
/*!
* flatten the higher dimension to second dimension, return a 2D shape
* \return the flat 2d shape
*/
inline Shape<2> FlatTo2D(void) const {
Shape<2> s;
if (ndim_ == 0) return Shape2(0, 0);
const index_t *d = this->data();
s.shape_[1] = d[ndim_ - 1];
index_t ymax = 1;
for (index_t i = 1; i < ndim_; ++i) {
ymax *= d[i - 1];
}
s.shape_[0] = ymax;
return s;
}
/*!
* flatten the shape into three parts: [0, axis_begin), [axis_begin, axis_end], (axis_end, ndim)
* \param axis_begin The beginning axis specified.
* \param axis_end The ending axis specified.
* \return the flat 3d shape
*/
inline Shape<3> FlatTo3D(index_t axis_begin, index_t axis_end) const {
CHECK(axis_end >= axis_begin);
Shape<3> s;
if (ndim_ == 0) return Shape3(0, 0, 0);
const index_t *d = this->data();
s.shape_[0] = 1;
s.shape_[1] = 1;
s.shape_[2] = 1;
for (index_t i = 0; i < axis_begin; ++i) {
s.shape_[0] *= d[i];
}
for (index_t i = axis_begin; i <= axis_end; ++i) {
s.shape_[1] *= d[i];
}
for (index_t i = axis_end + 1; i < ndim_; ++i) {
s.shape_[2] *= d[i];
}
return s;
}
/*!
* flatten the axis before and after the specified axis, so it becomes 3D tensor
* \param axis The axis specified.
* \return the flat 3d shape
*/
inline Shape<3> FlatTo3D(index_t axis) const {
return FlatTo3D(axis, axis);
}
/*!
* \return product shape in [dimstart,dimend)
* \param dimstart start dimension
* \param dimend end dimension
*/
inline index_t ProdShape(int dimstart, int dimend) const {
index_t num = 1;
const index_t *d = this->data();
for (int i = dimstart; i < dimend; ++i) {
num *= d[i];
}
return num;
}
/*!
* \brief get the shape of tensor specifying dim
* \return the shape requested
* \tparam dim dimension of the tensor
*/
template<int dim>
inline Shape<dim> get(void) const {
CHECK_EQ(dim, ndim_) << "dimension do not match target dimension " << dim << " vs " << ndim_;
const index_t *d = this->data();
Shape<dim> s;
for (int i = 0; i < dim; ++i) {
s[i] = d[i];
}
return s;
}
/*!
* \return whether two shape equals
* \param s the shape to compare against
*/
inline bool operator==(const TShape &s) const {
if (ndim_ != s.ndim_) return false;
if (ndim_ <= kStackCache) {
for (index_t i = 0; i < ndim_; ++i) {
if (data_stack_[i] != s.data_stack_[i]) return false;
}
} else {
for (index_t i = 0; i < ndim_; ++i) {
if (data_heap_[i] != s.data_heap_[i]) return false;
}
}
return true;
}
/*!
* \return whether two shape not equals
* \param s the shape to compare against
*/
inline bool operator!=(const TShape &s) const {
return !(*this == s);
}
/*!
* \return whether two shape equals
* \param s the shape to compare against
* \tparam dim dimension of the shape
*/
template<int dim>
inline bool operator==(const Shape<dim> &s) const {
if (ndim_ != dim) return false;
const index_t *d = dim <= kStackCache ? data_stack_ : data_heap_;
for (index_t i = 0; i < dim; ++i) {
if (d[i] != s.shape_[i]) return false;
}
return true;
}
/*!
* \return whether two shape not equals
* \param s the shape to compare against
* \tparam dim dimension of the shape
*/
template<int dim>
inline bool operator!=(const Shape<dim> &s) const {
return !(*this == s);
}
/*!
* \brief save the content into binary stream
* \param strm the output stream
 * \tparam TStream any stream type that has Write
*/
template<typename TStream>
inline void Save(TStream *strm) const {
strm->Write(&ndim_, sizeof(ndim_));
strm->Write(data(), sizeof(index_t) * ndim_);
}
/*!
* \brief load the content from binary stream
 * \param strm the input stream
 * \tparam TStream any stream type that has Read
* \return whether the load is successful
*/
template<typename TStream>
inline bool Load(TStream *strm) {
if (strm->Read(&ndim_, sizeof(ndim_)) != sizeof(ndim_)) return false;
this->SetDim(ndim_);
size_t nread = sizeof(index_t) * ndim_;
if (strm->Read(data(), nread) != nread) return false;
return true;
}
friend std::ostream &operator<<(std::ostream &os, const TShape &shape);
friend std::istream &operator>>(std::istream &is, TShape &shape);
private:
// the shape will be stored in data_stack_
// when dimension is smaller than kStackCache
// when it is bigger, it will be stored in data_heap_;
/*! \brief size of in stack space */
static const index_t kStackCache = 4;
/*! \brief number of dimension of the shape */
index_t ndim_;
/*! \brief number of cells allocated in data_heap_ */
index_t num_heap_allocated_;
/*! \brief in stack space used to store shape when it is small */
index_t data_stack_[kStackCache];
/*! \brief space to store shape when dimension is big*/
index_t *data_heap_;
/*!
* \brief internal function to set the dimension
* \param dim the dimension of the shape
*/
inline void SetDim(index_t dim) {
if (dim > kStackCache &&
dim > num_heap_allocated_) {
// data_heap_ can be NULL
delete [] data_heap_;
data_heap_ = new index_t[dim];
num_heap_allocated_ = dim;
}
ndim_ = dim;
}
};
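/*
 * Worked example for the TShape flattening helpers above (added for clarity).
 * For a hypothetical shape (2, 3, 4, 5):
 *
 *   FlatTo2D()      -> (2*3*4, 5)  = (24, 5)
 *   FlatTo3D(1, 2)  -> (2, 3*4, 5) = (2, 12, 5)
 *   FlatTo3D(2)     -> (2*3, 4, 5) = (6, 4, 5)
 *   ProdShape(1, 3) -> 3*4         = 12
 */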
/*!
* \brief allow string printing of the shape
* \param os the output stream
* \param shape the shape
* \return the ostream
*/
inline std::ostream &operator<<(std::ostream &os, const TShape &shape) {
os << '(';
for (index_t i = 0; i < shape.ndim(); ++i) {
if (i != 0) os << ',';
os << shape[i];
}
// python style tuple
if (shape.ndim() == 1) os << ',';
os << ')';
return os;
}
/*!
* \brief read shape from the istream
* \param is the input stream
* \param shape the shape
* \return the istream
*/
inline std::istream &operator>>(std::istream &is, TShape &shape) {
// get (
while (true) {
char ch = is.peek();
if (isdigit(ch)) {
index_t idx;
if (is >> idx) {
shape.CopyFrom(&idx, &idx + 1);
}
return is;
}
is.get();
if (ch == '(') break;
if (!isspace(ch)) {
is.setstate(std::ios::failbit);
return is;
}
}
index_t idx;
std::vector<index_t> tmp;
while (is >> idx) {
tmp.push_back(idx);
char ch;
do {
ch = is.get();
} while (isspace(ch));
if (ch == 'L') {
ch = is.get();
}
if (ch == ',') {
while (true) {
ch = is.peek();
if (isspace(ch)) {
is.get(); continue;
}
if (ch == ')') {
is.get(); break;
}
break;
}
if (ch == ')') break;
} else if (ch == ')') {
break;
} else {
is.setstate(std::ios::failbit);
return is;
}
}
shape.CopyFrom(tmp.begin(), tmp.end());
return is;
}
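/*
 * Note on the stream operators above (added for clarity): operator<< prints a
 * python-style tuple and operator>> accepts the same form, for example:
 *
 *   "(2,3,4)" -> TShape with dims {2, 3, 4}
 *   "(5,)"    -> TShape with dims {5}
 *   "7"       -> TShape with dims {7}   // bare leading digits are accepted
 *
 * An 'L' suffix after a dimension (as printed by python 2 long integers) is
 * skipped as well.
 */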
/*!
* \brief tensor blob class that can be used to hold tensor of any dimension,
* any device and any data type,
 * This is a weak type that can be used to transfer data through interfaces.
 * TBlob itself does not involve any arithmetic operations,
 * but it can be converted to a tensor of fixed dimension for further operations.
 *
 * Like Tensor, this data structure behaves like a pointer class and does not
 * implicitly allocate or de-allocate space.
 * It can be helpful for holding tensors of different dimensions
 * that are waiting for further processing
*/
class TBlob {
public:
/*! \brief pointer to the data */
void *dptr_;
/*! \brief shape of the tensor */
TShape shape_;
/*!
* \brief storing the stride information in x dimension
*/
index_t stride_;
/*! \brief device mask of the corresponding device */
int dev_mask_;
/*! \brief type flag of the tensor blob */
int type_flag_;
/*! \brief default constructor, default copy assign will work */
TBlob(void)
: dptr_(NULL), dev_mask_(cpu::kDevMask),
type_flag_(DataType<default_real_t>::kFlag) {}
/*!
* \brief constructor that construct TBlob from contiguous memory
* \param dptr the pointer to the memory
* \param shape the shape of the data
* \param dev_mask the device mask, can be cpu::kDevMask or gpu::kDevMask
*/
template<typename DType>
TBlob(DType *dptr,
const TShape &shape,
int dev_mask)
: dptr_(dptr), shape_(shape),
stride_(shape[shape.ndim() - 1]),
dev_mask_(dev_mask),
type_flag_(DataType<DType>::kFlag) {}
/*!
* \brief constructor that construct TBlob from contiguous memory
* \param dptr the pointer to the memory
* \param shape the shape of the data
* \param dev_mask the device mask, can be cpu::kDevMask or gpu::kDevMask
* \param type_flag the type flag. Can be one of enum mshadow::dtype
*/
TBlob(void *dptr,
const TShape &shape,
int dev_mask,
int type_flag)
: dptr_(dptr), shape_(shape),
stride_(shape[shape.ndim() - 1]),
dev_mask_(dev_mask),
type_flag_(type_flag) {}
/*!
* \brief constructor from tensor
* \param src source tensor
* \tparam Device which device the tensor is on
* \tparam dim tensor dimension
* \tparam DType the type of elements in the tensor
*/
template<typename Device, int dim, typename DType>
TBlob(const Tensor<Device, dim, DType> &src) { // NOLINT(*)
*this = src;
}
/*!
* \brief assignment from tensor
* \param src source tensor
* \tparam Device which device the tensor is on
* \tparam dim tensor dimension
* \tparam DType the type of elements in the tensor
* \return reference of self
*/
template<typename Device, int dim, typename DType>
inline TBlob
&operator=(const Tensor<Device, dim, DType> &src) {
dptr_ = src.dptr_;
shape_ = src.shape_;
stride_ = src.stride_;
dev_mask_ = Device::kDevMask;
type_flag_ = DataType<DType>::kFlag;
return *this;
}
/*!
* \return whether the tensor's memory is continuous
*/
inline bool CheckContiguous(void) const {
return shape_[shape_.ndim() - 1] == stride_;
}
/*!
* \brief flatten the tensor to 2 dimension, collapse the higher dimensions together
* \param stream the possible stream target tensor should reside on
* \tparam Device which device the tensor is on
* \tparam DType the type of elements in the tensor
* \return tensor after flatten
*/
template<typename Device, typename DType>
inline Tensor<Device, 2, DType> FlatTo2D(Stream<Device> *stream = NULL) const {
CHECK(Device::kDevMask == dev_mask_)
      << "TBlob.FlatTo2D: device type do not match specified type";
    CHECK(DataType<DType>::kFlag == type_flag_)
      << "TBlob.FlatTo2D: data type do not match specified type."
<< "Expected: " << type_flag_ << " v.s. given " << DataType<DType>::kFlag;
return Tensor<Device, 2, DType>(static_cast<DType*>(dptr_),
shape_.FlatTo2D(), stride_, stream);
}
/*! \brief return number of dimension of the tensor inside */
inline int ndim(void) const {
return shape_.ndim();
}
/*!
* \brief return size of i-th dimension, start counting from highest dimension
* \param idx the dimension count from the highest dimension
* \return the size
*/
inline index_t size(index_t idx) const {
return shape_[idx];
}
/*! \brief total number of elements in the tensor */
inline index_t Size(void) const {
return shape_.Size();
}
/*!
* \brief fetch the tensor, with respect to specific dimension
   * if dim does not match the stored dimension, an error will be issued
* \return the tensor requested
* \param stream the possible stream target tensor should reside on
* \tparam Device which device the tensor is on
* \tparam dim dimension of the tensor
* \tparam DType the type of elements in the tensor
*/
template<typename Device, int dim, typename DType>
inline Tensor<Device, dim, DType> get(Stream<Device> *stream = NULL) const {
CHECK(Device::kDevMask == dev_mask_)
<< "TBlob.get: device type do not match specified type";
CHECK(DataType<DType>::kFlag == type_flag_)
      << "TBlob.get: data type do not match specified type."
<< "Expected: " << type_flag_ << " v.s. given " << DataType<DType>::kFlag;
return Tensor<Device, dim, DType>(static_cast<DType*>(dptr_),
shape_.get<dim>(),
stride_, stream);
}
/*!
* \brief fetch a tensor in given shape
   * If size does not match the stored size, an error will be issued
* \return the tensor requested
* \param shape the shape required
* \param stream the possible stream target tensor should reside on
* \tparam Device which device the tensor is on
* \tparam dim dimension of the tensor
* \tparam DType the type of elements in the tensor
*/
template<typename Device, int dim, typename DType>
inline Tensor<Device, dim, DType> get_with_shape(const Shape<dim> &shape,
Stream<Device> *stream = NULL) const {
CHECK(Device::kDevMask == dev_mask_)
      << "TBlob.get_with_shape: device type do not match specified type";
CHECK(DataType<DType>::kFlag == type_flag_)
<< "TBlob.get_with_shape: data type do not match specified type."
<< "Expected: " << type_flag_ << " v.s. given " << DataType<DType>::kFlag;
    CHECK_EQ(this->CheckContiguous(), true) << "TBlob.get_with_shape: must be contiguous";
CHECK_EQ(this->shape_.Size(), shape.Size())
<< "TBlob.get_with_shape: new and old shape do not match total elements";
return Tensor<Device, dim, DType>(static_cast<DType*>(dptr_),
shape,
shape[dim - 1],
stream);
}
/*!
* \brief flatten the tensor to 3 dimension,
* collapse the dimension before and after specified axis.
* \param axis The axis specified.
* \param stream the possible stream target tensor should reside on
* \tparam Device which device the tensor is on
* \tparam DType the type of elements in the tensor
* \return tensor after flatten
*/
template<typename Device, typename DType>
inline Tensor<Device, 3, DType> FlatTo3D(int axis, Stream<Device> *stream = NULL) const {
return this->get_with_shape<Device, 3, DType>(
this->shape_.FlatTo3D(axis), stream);
}
/*!
* \brief flatten the tensor to 3 dimension,
* collapse the dimension: [0, axis_begin), [axis_begin, axis_end], (axis_end, ndim).
* \param axis_begin The beginning axis specified.
* \param axis_end The ending axis specified.
* \param stream the possible stream target tensor should reside on
* \tparam Device which device the tensor is on
* \tparam DType the type of elements in the tensor
* \return tensor after flatten
*/
template<typename Device, typename DType>
inline Tensor<Device, 3, DType> FlatTo3D(int axis_begin, int axis_end,
Stream<Device> *stream = NULL) const {
return this->get_with_shape<Device, 3, DType>(
this->shape_.FlatTo3D(axis_begin, axis_end), stream);
}
};
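/*
 * Minimal usage sketch for TBlob (added for illustration; names are
 * hypothetical). A TBlob only records pointer, shape, device and data type;
 * it is converted back to a typed Tensor before any arithmetic:
 *
 *   Tensor<cpu, 3, float> data(Shape3(2, 3, 4));
 *   AllocSpace(&data);
 *   TBlob blob(data);                                        // type-erased view
 *   Tensor<cpu, 3, float> t3 = blob.get<cpu, 3, float>();
 *   Tensor<cpu, 2, float> t2 = blob.FlatTo2D<cpu, float>();  // shape (6, 4)
 *   FreeSpace(&data);
 */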
} // namespace mshadow
#endif // MSHADOW_TENSOR_BLOB_H_
//===== EXPANDED: ../mshadow/mshadow/tensor_blob.h =====
//===== EXPANDING: ../mshadow/mshadow/random.h =====
/*!
* Copyright (c) 2014 by Contributors
* \file random.h
* \brief Random inline functions for tensor.
* \author Bing Xu, Tianqi Chen
* Based on curand|MKL|stdlib
*/
#ifndef MSHADOW_RANDOM_H_
#define MSHADOW_RANDOM_H_
#if MSHADOW_IN_CXX11
#endif
#if _MSC_VER
#define rand_r(x) rand()
#endif
namespace mshadow {
/*!
* \brief random number generator
* \tparam Device the device of random number generator
 * \tparam DType the target data type of random number, can be float or double
*/
template<typename Device, typename DType MSHADOW_DEFAULT_DTYPE>
class Random {};
/*! \brief CPU random number generator */
template<typename DType>
class Random<cpu, DType> {
public:
/*!
* \brief constructor of random engine
* \param seed random number seed
*/
explicit Random(int seed) {
this->Seed(seed);
buffer_.Resize(Shape1(kRandBufferSize));
}
~Random(void) {
}
/*!
* \brief seed random number generator using this seed
* \param seed seed of prng
*/
inline void Seed(int seed) {
#if MSHADOW_IN_CXX11
rnd_engine_.seed(seed);
#endif
this->rseed_ = static_cast<unsigned>(seed);
}
/*!
* \brief get random seed used in random generator
* \return seed in unsigned
*/
inline unsigned GetSeed() const {
return rseed_;
}
/*!
* \brief set the stream of computation
* \param stream computation stream
*/
inline void set_stream(Stream<cpu> *stream) {
}
/*!
* \brief generate data from uniform [a,b)
* \param dst destination
* \param a lower bound of uniform
* \param b upper bound of uniform
* \tparam dim dimension of tensor
*/
template<int dim>
inline void SampleUniform(Tensor<cpu, dim, DType> *dst,
DType a = 0.0f, DType b = 1.0f) {
if (dst->CheckContiguous()) {
this->GenUniform(dst->dptr_, dst->shape_.Size(), a, b);
} else {
Tensor<cpu, 2, DType> mat = dst->FlatTo2D();
for (index_t i = 0; i < mat.size(0); ++i) {
this->GenUniform(mat[i].dptr_, mat.size(1), a, b);
}
}
}
/*!
* \brief generate data from standard gaussian
* \param dst destination
* \param mu mean variable
* \param sigma standard deviation
* \tparam dim dimension of tensor
*/
template<int dim>
inline void SampleGaussian(Tensor<cpu, dim, DType> *dst,
DType mu = 0.0f, DType sigma = 1.0f) {
if (sigma <= 0.0f) {
*dst = mu; return;
}
if (dst->CheckContiguous()) {
this->GenGaussian(dst->dptr_, dst->shape_.Size(), mu, sigma);
} else {
Tensor<cpu, 2, DType> mat = dst->FlatTo2D();
for (index_t i = 0; i < mat.size(0); ++i) {
this->GenGaussian(mat[i].dptr_, mat.size(1), mu, sigma);
}
}
}
/*!
* \brief return a temporal expression storing standard gaussian random variables
* the temporal tensor is only valid before next call of gaussian or uniform
* can be used as part of expression
* Caution: this means expression such as A = gaussian(s1) * gaussian(s2) will give invalid result,
* since second call of gaussian(s2) makes gaussian(s1) invalid
* A = gaussian(s1)*B+C; is correct; use one gaussian/uniform in each expression
* \param shape shape of the tensor
* \return a temporal expression storing standard gaussian random variables
* \tparam dim dimension of tensor
*/
template<int dim>
inline expr::ReshapeExp<Tensor<cpu, 1, DType>, DType, dim, 1>
gaussian(Shape<dim> shape) {
buffer_.Resize(Shape1(shape.Size()));
this->SampleGaussian(&buffer_, 0.0f, 1.0f);
return expr::reshape(buffer_, shape);
}
/*!
* \brief return a temporal expression storing standard uniform [0,1)
* the temporal tensor is only valid before next call of gaussian or uniform
* can be used as part of expression
* Caution: this means expression such as A = uniform(s1) * uniform(s2) will give invalid result,
   * since the second call of uniform(s2) makes uniform(s1) invalid
   * A = uniform(s1)*B+C; is correct; use one gaussian/uniform in each expression
* \param shape shape of the tensor
* \return a temporal expression storing standard uniform [0,1)
* \tparam dim dimension of tensor
*/
template<int dim>
inline expr::ReshapeExp<Tensor<cpu, 1, DType>, DType, dim, 1>
uniform(Shape<dim> shape) {
buffer_.Resize(Shape1(shape.Size()));
this->SampleUniform(&buffer_, 0.0f, 1.0f);
return expr::reshape(buffer_, shape);
}
private:
#if MSHADOW_IN_CXX11
/*! \brief use c++11 random engine. */
std::mt19937 rnd_engine_;
/*! \brief random number seed used in random engine */
unsigned rseed_;
// implementing generators.
inline void GenUniform(DType *dptr, index_t size, DType a, DType b) {
std::uniform_real_distribution<DType> dist_uniform(a, b);
for (size_t i = 0; i < size; ++i) {
dptr[i] = dist_uniform(rnd_engine_);
}
}
inline void GenGaussian(DType *dptr, index_t size, DType mu, DType sigma) {
std::normal_distribution<DType> dist_normal(mu, sigma);
for (size_t i = 0; i < size; ++i) {
dptr[i] = dist_normal(rnd_engine_);
}
}
#else
/*! \brief random number seed used by PRNG */
unsigned rseed_;
// functions
inline void GenUniform(float *dptr, index_t size, float a, float b) {
for (index_t j = 0; j < size; ++j) {
dptr[j] = static_cast<float>(RandNext()) * (b - a) + a;
}
}
inline void GenUniform(double *dptr, index_t size, double a, double b) {
for (index_t j = 0; j < size; ++j) {
dptr[j] = static_cast<double>(RandNext()) * (b - a) + a;
}
}
inline void GenGaussian(float *dptr, index_t size, float mu, float sigma) {
this->GenGaussianX(dptr, size, mu, sigma);
}
inline void GenGaussian(double *dptr, index_t size, double mu, double sigma) {
this->GenGaussianX(dptr, size, mu, sigma);
}
inline void GenGaussianX(DType *dptr, index_t size, DType mu, DType sigma) {
DType g1 = 0.0f, g2 = 0.0f;
for (index_t j = 0; j < size; ++j) {
if ((j & 1) == 0) {
this->SampleNormal2D(&g1, &g2);
dptr[j] = mu + g1 * sigma;
} else {
dptr[j] = mu + g2 * sigma;
}
}
}
/*! \brief get next random number from rand */
inline DType RandNext(void) {
return static_cast<DType>(rand_r(&rseed_)) /
(static_cast<DType>(RAND_MAX) + 1.0f);
}
  /*! \brief return a real number uniform in (0,1) */
inline DType RandNext2(void) {
return (static_cast<DType>(rand_r(&rseed_)) + 1.0f) /
(static_cast<DType>(RAND_MAX) + 2.0f);
}
/*!
* \brief sample iid xx,yy ~N(0,1)
* \param xx first gaussian output
* \param yy second gaussian output
*/
inline void SampleNormal2D(DType *xx_, DType *yy_) {
DType &xx = *xx_, &yy = *yy_;
DType x, y, s;
do {
x = 2.0f * RandNext2() - 1.0f;
y = 2.0f * RandNext2() - 1.0f;
s = x * x + y * y;
} while (s >= 1.0f || s == 0.0f);
DType t = std::sqrt(-2.0f * std::log(s) / s);
xx = x * t; yy = y * t;
}
#endif
/*! \brief temporal space used to store random numbers */
TensorContainer<cpu, 1, DType> buffer_;
}; // class Random<cpu, DType>
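/*
 * Usage sketch for Random<cpu, DType> (added for illustration), following the
 * caution documented above: use at most one gaussian()/uniform() expression
 * per assignment, since every call reuses the same internal buffer.
 *
 *   Random<cpu, float> rnd(0);
 *   TensorContainer<cpu, 2, float> mat(Shape2(3, 4));
 *   rnd.SampleUniform(&mat, -1.0f, 1.0f);    // fill in place
 *   mat = rnd.gaussian(mat.shape_) * 0.1f;   // one random expression per statement
 */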
// only allow GPU PRNG when cuda is enabled
#if MSHADOW_USE_CUDA
/*! \brief GPU random number generator */
template<typename DType>
class Random<gpu, DType> {
public:
/*!
* \brief constructor of random engine
* \param seed random number seed
*/
explicit Random(int seed) {
curandStatus_t status;
status = curandCreateGenerator(&gen_, CURAND_RNG_PSEUDO_DEFAULT);
CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "Can not create CURAND Generator";
this->Seed(seed);
buffer_.Resize(Shape1(kRandBufferSize));
}
~Random(void) MSHADOW_THROW_EXCEPTION {
curandStatus_t status;
status = curandDestroyGenerator(gen_);
    CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "Destroy CURAND generator failed";
}
/*!
* \brief set the stream of computation
* \param stream computation stream
*/
inline void set_stream(Stream<gpu> *stream) {
curandStatus_t status;
status = curandSetStream(gen_, Stream<gpu>::GetStream(stream));
CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "set_stream CURAND failed";
}
/*!
* \brief seed random number generator using this seed
* \param seed seed of prng
*/
inline void Seed(int seed) {
curandStatus_t status;
status = curandSetPseudoRandomGeneratorSeed(gen_, seed);
CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "Set CURAND seed failed.";
}
/*!
* \brief generate data from uniform [a,b)
* \param dst destination
* \param a lower bound of uniform
* \param b upper bound of uniform
* \tparam dim dimension of tensor
*/
template<int dim>
inline void SampleUniform(Tensor<gpu, dim, DType> *dst,
DType a = 0.0f, DType b = 1.0f);
/*!
* \brief generate data from standard gaussian
* \param dst destination
* \param mu mean variable
* \param sigma standard deviation
* \tparam dim dimension of tensor
*/
template<int dim>
inline void SampleGaussian(Tensor<gpu, dim, DType> *dst,
DType mu = 0.0f, DType sigma = 1.0f);
/*!
* \brief return a temporal expression storing standard gaussian random variables
* the temporal tensor is only valid before next call of gaussian or uniform
* can be used as part of expression
* Caution: this means expression such as A = gaussian(s1) * gaussian(s2) will give invalid result,
* since second call of gaussian(s2) makes gaussian(s1) invalid
* A = gaussian(s1)*B+C; is correct; use one gaussian/uniform in each expression
* \param shape shape of the tensor
* \param mu mean
   * \param sigma standard deviation
* \return a temporal expression storing standard gaussian random variables
* \tparam dim dimension of tensor
*/
template<int dim>
inline expr::ReshapeExp<Tensor<gpu, 1, DType>, DType, dim, 1>
gaussian(Shape<dim> shape, DType mu = 0.0f, DType sigma = 1.0f);
/*!
* \brief return a temporal expression storing standard uniform [0,1)
* the temporal tensor is only valid before next call of gaussian or uniform
* can be used as part of expression
   * Caution: this means expression such as A = uniform(s1) * uniform(s2) will give invalid result,
   * since the second call of uniform(s2) makes uniform(s1) invalid
   * A = uniform(s1)*B+C; is correct; use one gaussian/uniform in each expression
* \param shape shape of the tensor
* \return a temporal expression storing standard uniform [0,1)
* \tparam dim dimension of tensor
*/
template<int dim>
inline expr::ReshapeExp<Tensor<gpu, 1, DType>, DType, dim, 1>
uniform(Shape<dim> shape);
private:
inline void GenGaussian(float *dptr, size_t size, float mu, float sigma) {
curandStatus_t status;
status = curandGenerateNormal(gen_, dptr, size, mu, sigma);
CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "CURAND Gen Normal float failed."
<< " size = " << size
<< ",mu = " << mu
<< ",sigma = " << sigma;
}
inline void GenGaussian(double *dptr, size_t size, double mu, double sigma) {
curandStatus_t status;
status = curandGenerateNormalDouble(gen_, dptr, size, mu, sigma);
CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "CURAND Gen Normal double failed."
<< " size = " << size
<< ",mu = " << mu
<< ",sigma = " << sigma;
}
inline void GenUniform(float *dptr, size_t size) {
curandStatus_t status;
status = curandGenerateUniform(gen_, dptr, size);
CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "CURAND Gen Uniform float failed."
<< " size = " << size;
}
inline void GenUniform(double *dptr, size_t size) {
curandStatus_t status;
status = curandGenerateUniformDouble(gen_, dptr, size);
CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "CURAND Gen Uniform double failed."
<< " size = " << size;
}
  /*! \brief random number generator */
curandGenerator_t gen_;
  /*! \brief temporary buffer */
TensorContainer<gpu, 1, DType> buffer_;
}; // class Random<gpu, DType>
#endif // MSHADOW_USE_CUDA
#ifdef __CUDACC__
// implementations that depends on cuda kernels
template<typename DType>
template<int dim>
inline void Random<gpu, DType>::SampleUniform(
Tensor<gpu, dim, DType> *dst, DType a, DType b) {
if (a == 0.0f && b == 1.0f) {
if (dst->CheckContiguous()) {
this->GenUniform(dst->dptr_, dst->shape_.Size());
} else {
*dst = this->uniform(dst->shape_);
}
} else {
*dst = this->uniform(dst->shape_) * (b - a) + a;
}
}
template<typename DType>
template<int dim>
inline void Random<gpu, DType>::SampleGaussian(
Tensor<gpu, dim, DType> *dst, DType mu, DType sigma) {
  // We need to check whether the shape size is even, since cuRAND only supports
  // generating an even number of normally distributed elements.
if (dst->CheckContiguous() && (dst->shape_.Size() % 2 == 0)) {
this->GenGaussian(dst->dptr_, dst->shape_.Size(), mu, sigma);
} else {
*dst = this->gaussian(dst->shape_, mu, sigma);
}
}
template<typename DType>
template<int dim>
inline expr::ReshapeExp<Tensor<gpu, 1, DType>, DType, dim, 1>
Random<gpu, DType>::gaussian(Shape<dim> shape, DType mu, DType sigma) {
size_t aligned_sz = ((shape.Size() + 1UL) >> 1) << 1;
  // allocate aligned size
buffer_.Resize(Shape1(aligned_sz));
buffer_.Resize(Shape1(shape.Size()));
this->GenGaussian(buffer_.dptr_, aligned_sz, mu, sigma);
return expr::reshape(buffer_, shape);
}
template<typename DType>
template<int dim>
inline expr::ReshapeExp<Tensor<gpu, 1, DType>, DType, dim, 1>
Random<gpu, DType>::uniform(Shape<dim> shape) {
buffer_.Resize(Shape1(shape.Size()));
this->GenUniform(buffer_.dptr_, buffer_.size(0));
return expr::reshape(buffer_, shape);
}
#endif // __CUDACC__
} // namespace mshadow
#endif // MSHADOW_RANDOM_H_
//===== EXPANDED: ../mshadow/mshadow/random.h =====
// add definition of scalar related operators
#ifdef MSHADOW_SCALAR_
#error "MSHADOW_SCALAR_ must not be defined"
#endif
// enumerate all the scalar data type we aim to be good at
#define MSHADOW_SCALAR_ float
//===== EXPANDING: ../mshadow/mshadow/expr_scalar-inl.h =====
/*!
* Copyright (c) 2014 by Contributors
* \file expr_scalar-inl.h
* \brief definitions of operators in expression with respect to scalar
 *  this file will be included several times, each time with the macro MSHADOW_SCALAR_ defined to a different type
*
* DO NOT add pragma once or macro guard
* \author Tianqi Chen, Bing Xu
*/
// macro guard is harmful, used to pass the cpplint
#ifndef MSHADOW_EXPR_SCALAR_INL_H_
#define MSHADOW_EXPR_SCALAR_INL_H_
// undef the guard so it can be included multiple times
#undef MSHADOW_EXPR_SCALAR_INL_H_
namespace mshadow {
namespace expr {
// DotExp
/*! \brief dot operator def */
template<typename TA, typename TB, bool ltrans, bool rtrans>
inline DotExp<TA, TB, ltrans, rtrans, MSHADOW_SCALAR_>
operator*(const DotExp<TA, TB, ltrans, rtrans, MSHADOW_SCALAR_> &lhs,
MSHADOW_SCALAR_ rhs) {
return DotExp<TA, TB, ltrans, rtrans,
MSHADOW_SCALAR_>(lhs.lhs_, lhs.rhs_, lhs.scale_ * rhs);
}
/*! \brief scale of dot operation */
template<typename TA, typename TB, bool ltrans, bool rtrans>
inline DotExp<TA, TB, ltrans, rtrans, MSHADOW_SCALAR_>
operator*(MSHADOW_SCALAR_ lhs,
const DotExp<TA, TB, ltrans, rtrans, MSHADOW_SCALAR_> &rhs) {
return DotExp<TA, TB, ltrans, rtrans,
MSHADOW_SCALAR_>(rhs.lhs_, rhs.rhs_, rhs.scale_ * lhs);
}
/*! \brief operator overload */
template<typename E, typename DType, typename R, int d>
inline ReduceTo1DExp<E, DType, R, d>
operator*(const ReduceTo1DExp<E, DType, R, d> &e, MSHADOW_SCALAR_ scale) {
return ReduceTo1DExp<E, DType, R, d>(e.src_, e.scale_ * scale);
}
/*! \brief operator overload */
template<typename E, typename DType, typename R, int d>
inline ReduceTo1DExp<E, DType, R, d>
operator*(MSHADOW_SCALAR_ scale, const ReduceTo1DExp<E, DType, R, d> &e) {
return ReduceTo1DExp<E, DType, R, d>(e.src_, e.scale_ * scale);
}
/*! \brief operator overload for const */
template<typename OP, typename TA, int ta>
inline BinaryMapExp<OP, TA, ScalarExp<MSHADOW_SCALAR_>,
MSHADOW_SCALAR_, (ta|type::kMapper)>
F(const Exp<TA, MSHADOW_SCALAR_, ta> &lhs, const ScalarExp<MSHADOW_SCALAR_> &rhs) {
return MakeExp<OP>(lhs, rhs);
}
/*! \brief operator overload for const */
template<typename OP, typename TB, int tb>
inline BinaryMapExp<OP, ScalarExp<MSHADOW_SCALAR_>, TB,
MSHADOW_SCALAR_, (tb|type::kMapper)>
F(const ScalarExp<MSHADOW_SCALAR_> &lhs, const Exp<TB, MSHADOW_SCALAR_, tb> &rhs) {
return MakeExp<OP>(lhs, rhs);
}
/*! \brief operator overload for const */
template<typename OP>
inline BinaryMapExp<OP, ScalarExp<MSHADOW_SCALAR_>, ScalarExp<MSHADOW_SCALAR_>,
MSHADOW_SCALAR_, (1|type::kMapper)>
F(const ScalarExp<MSHADOW_SCALAR_> &lhs, const ScalarExp<MSHADOW_SCALAR_> &rhs) {
return MakeExp<OP>(lhs, rhs);
}
// constant operators
/*! \brief operator overload */
template<typename TA, int ta>
inline BinaryMapExp<op::plus, TA, ScalarExp<MSHADOW_SCALAR_>,
MSHADOW_SCALAR_, (ta|type::kMapper)>
operator+(const Exp<TA, MSHADOW_SCALAR_, ta> &lhs,
const ScalarExp<MSHADOW_SCALAR_> &rhs) {
return MakeExp<op::plus>(lhs, rhs);
}
/*! \brief operator overload */
template<typename TA, int ta>
inline BinaryMapExp<op::minus, TA, ScalarExp<MSHADOW_SCALAR_>,
MSHADOW_SCALAR_, (ta|type::kMapper)>
operator-(const Exp<TA, MSHADOW_SCALAR_, ta> &lhs,
const ScalarExp<MSHADOW_SCALAR_> &rhs) {
return MakeExp<op::minus>(lhs, rhs);
}
/*! \brief operator overload */
template<typename TA, int ta>
inline BinaryMapExp<op::mul, TA, ScalarExp<MSHADOW_SCALAR_>,
MSHADOW_SCALAR_, (ta|type::kMapper)>
operator*(const Exp<TA, MSHADOW_SCALAR_, ta> &lhs,
const ScalarExp<MSHADOW_SCALAR_> &rhs) {
return MakeExp<op::mul>(lhs, rhs);
}
/*! \brief operator overload */
template<typename TA, int ta>
inline BinaryMapExp<op::div, TA, ScalarExp<MSHADOW_SCALAR_>,
MSHADOW_SCALAR_, (ta|type::kMapper)>
operator/(const Exp<TA, MSHADOW_SCALAR_, ta> &lhs,
const ScalarExp<MSHADOW_SCALAR_> &rhs) {
return MakeExp<op::div>(lhs, rhs);
}
// constant operators 2
/*! \brief operator overload */
template<typename TB, int tb>
inline BinaryMapExp<op::plus, ScalarExp<MSHADOW_SCALAR_>, TB,
MSHADOW_SCALAR_, (tb|type::kMapper)>
operator+(const ScalarExp<MSHADOW_SCALAR_> &lhs,
const Exp<TB, MSHADOW_SCALAR_, tb> &rhs) {
return MakeExp<op::plus>(lhs, rhs);
}
/*! \brief operator overload */
template<typename TB, int tb>
inline BinaryMapExp<op::minus, ScalarExp<MSHADOW_SCALAR_>, TB,
MSHADOW_SCALAR_, (tb|type::kMapper)>
operator-(const ScalarExp<MSHADOW_SCALAR_> &lhs,
const Exp<TB, MSHADOW_SCALAR_, tb> &rhs) {
return MakeExp<op::minus>(lhs, rhs);
}
/*! \brief operator overload */
template<typename TB, int tb>
inline BinaryMapExp<op::mul, ScalarExp<MSHADOW_SCALAR_>, TB,
MSHADOW_SCALAR_, (tb|type::kMapper)>
operator*(const ScalarExp<MSHADOW_SCALAR_> &lhs,
const Exp<TB, MSHADOW_SCALAR_, tb> &rhs) {
return MakeExp<op::mul>(lhs, rhs);
}
/*! \brief operator overload */
template<typename TB, int tb>
inline BinaryMapExp<op::div, ScalarExp<MSHADOW_SCALAR_>, TB,
MSHADOW_SCALAR_, (tb|type::kMapper)>
operator/(const ScalarExp<MSHADOW_SCALAR_> &lhs, const Exp<TB, MSHADOW_SCALAR_, tb> &rhs) {
return MakeExp<op::div>(lhs, rhs);
}
// constant operators 3
/*! \brief operator overload */
inline BinaryMapExp<op::plus, ScalarExp<MSHADOW_SCALAR_>, ScalarExp<MSHADOW_SCALAR_>,
MSHADOW_SCALAR_, (1|type::kMapper)>
operator+(const ScalarExp<MSHADOW_SCALAR_> &lhs,
const ScalarExp<MSHADOW_SCALAR_> &rhs) {
return MakeExp<op::plus>(lhs, rhs);
}
/*! \brief operator overload */
inline BinaryMapExp<op::minus, ScalarExp<MSHADOW_SCALAR_>, ScalarExp<MSHADOW_SCALAR_>,
MSHADOW_SCALAR_, (1|type::kMapper)>
operator-(const ScalarExp<MSHADOW_SCALAR_> &lhs,
const ScalarExp<MSHADOW_SCALAR_> &rhs) {
return MakeExp<op::minus>(lhs, rhs);
}
/*! \brief operator overload */
inline BinaryMapExp<op::mul, ScalarExp<MSHADOW_SCALAR_>, ScalarExp<MSHADOW_SCALAR_>,
MSHADOW_SCALAR_, (1|type::kMapper)>
operator*(const ScalarExp<MSHADOW_SCALAR_> &lhs,
const ScalarExp<MSHADOW_SCALAR_> &rhs) {
return MakeExp<op::mul>(lhs, rhs);
}
/*! \brief operator overload */
inline BinaryMapExp<op::div, ScalarExp<MSHADOW_SCALAR_>, ScalarExp<MSHADOW_SCALAR_>,
MSHADOW_SCALAR_, (1|type::kMapper)>
operator/(const ScalarExp<MSHADOW_SCALAR_> &lhs, const ScalarExp<MSHADOW_SCALAR_> &rhs) {
return MakeExp<op::div>(lhs, rhs);
}
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXPR_SCALAR_INL_H_
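// A minimal sketch (not part of the library): once these scalar overloads are
// instantiated for MSHADOW_SCALAR_ (float here), expressions can freely mix
// tensors and plain scalars. The tensor names below are hypothetical and
// assumed to be allocated with the same shape.
//
// \code
//   mshadow::Tensor<mshadow::cpu, 2, float> A, B;
//   A = B * 2.0f + 1.0f;       // scale then shift, evaluated lazily element-wise
//   A = 1.0f / (B + 1e-6f);    // a scalar on the left-hand side also works
// \endcode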
//===== EXPANDED: ../mshadow/mshadow/expr_scalar-inl.h =====
#undef MSHADOW_SCALAR_
#define MSHADOW_SCALAR_ double
#undef MSHADOW_SCALAR_
#define MSHADOW_SCALAR_ int
#undef MSHADOW_SCALAR_
#define MSHADOW_SCALAR_ mshadow::half::half_t
#undef MSHADOW_SCALAR_
#endif // MSHADOW_TENSOR_H_
//===== EXPANDED: ../mshadow/mshadow/tensor.h =====
//===== EXPANDING: ../include/mxnet/base.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file base.h
 * \brief configuration of mxnet as well as basic data structures.
*/
#ifndef MXNET_BASE_H_
#define MXNET_BASE_H_
// nnvm headers for symbolic construction.
//===== EXPANDING: ../nnvm/include/nnvm/op.h =====
/*!
* Copyright (c) 2016 by Contributors
* \file op.h
 * \brief Operator information structure.
*/
#ifndef NNVM_OP_H_
#define NNVM_OP_H_
//===== EXPANDING: ../nnvm/include/nnvm/base.h =====
/*!
* Copyright (c) 2016 by Contributors
* \file base.h
 * \brief Configuration of nnvm as well as basic data structures.
*/
#ifndef NNVM_BASE_H_
#define NNVM_BASE_H_
//===== EXPANDING: ../dmlc-core/include/dmlc/memory.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file memory.h
 * \brief Additional memory handling utilities.
*/
#ifndef DMLC_MEMORY_H_
#define DMLC_MEMORY_H_
//===== EXPANDING: ../dmlc-core/include/dmlc/thread_local.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file thread_local.h
* \brief Portable thread local storage.
*/
#ifndef DMLC_THREAD_LOCAL_H_
#define DMLC_THREAD_LOCAL_H_
namespace dmlc {
// macro handling for thread-local variables
#ifdef __GNUC__
#define MX_TREAD_LOCAL __thread
#elif __STDC_VERSION__ >= 201112L
#define MX_TREAD_LOCAL _Thread_local
#elif defined(_MSC_VER)
#define MX_TREAD_LOCAL __declspec(thread)
#endif
#ifndef MX_TREAD_LOCAL
#pragma message("Warning: Threadlocal is not enabled")
#endif
/*!
* \brief A threadlocal store to store threadlocal variables.
* Will return a thread local singleton of type T
* \tparam T the type we like to store
*/
template<typename T>
class ThreadLocalStore {
public:
/*! \return get a thread local singleton */
static T* Get() {
static MX_TREAD_LOCAL T* ptr = nullptr;
if (ptr == nullptr) {
ptr = new T();
Singleton()->RegisterDelete(ptr);
}
return ptr;
}
private:
/*! \brief constructor */
ThreadLocalStore() {}
/*! \brief destructor */
~ThreadLocalStore() {
for (size_t i = 0; i < data_.size(); ++i) {
delete data_[i];
}
}
/*! \return singleton of the store */
static ThreadLocalStore<T> *Singleton() {
static ThreadLocalStore<T> inst;
return &inst;
}
/*!
   * \brief register a pointer for internal deletion when the store is destructed
   * \param str the pointer to be registered
*/
void RegisterDelete(T *str) {
std::unique_lock<std::mutex> lock(mutex_);
data_.push_back(str);
lock.unlock();
}
/*! \brief internal mutex */
std::mutex mutex_;
/*!\brief internal data */
std::vector<T*> data_;
};
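// A short illustrative sketch (the Counter type below is made up for the
// example): each calling thread lazily gets its own instance of T.
//
// \code
//   struct Counter { int value = 0; };
//   void bump() {
//     Counter* c = dmlc::ThreadLocalStore<Counter>::Get();  // per-thread singleton
//     ++c->value;
//   }
// \endcode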
} // namespace dmlc
#endif // DMLC_THREAD_LOCAL_H_
//===== EXPANDED: ../dmlc-core/include/dmlc/thread_local.h =====
namespace dmlc {
/*!
 * \brief A memory pool that allocates memory of fixed size and alignment.
* \tparam size The size of each piece.
* \tparam align The alignment requirement of the memory.
*/
template<size_t size, size_t align>
class MemoryPool {
public:
/*! \brief constructor */
MemoryPool() {
static_assert(align % alignof(LinkedList) == 0,
"alignment requirement failed.");
curr_page_.reset(new Page());
}
/*! \brief allocate a new memory of size */
inline void* allocate() {
if (head_ != nullptr) {
LinkedList* ret = head_;
head_ = head_->next;
return ret;
} else {
if (page_ptr_ < kPageSize) {
return &(curr_page_->data[page_ptr_++]);
} else {
allocated_.push_back(std::move(curr_page_));
curr_page_.reset(new Page());
page_ptr_ = 1;
return &(curr_page_->data[0]);
}
}
}
/*!
* \brief deallocate a piece of memory
* \param p The pointer to the memory to be de-allocated.
*/
inline void deallocate(void* p) {
LinkedList* ptr = static_cast<LinkedList*>(p);
ptr->next = head_;
head_ = ptr;
}
private:
// page size of each member
static const int kPageSize = ((1 << 22) / size);
// page to be requested.
struct Page {
typename std::aligned_storage<size, align>::type data[kPageSize];
};
// internal linked list structure.
struct LinkedList {
LinkedList* next{nullptr};
};
// head of free list
LinkedList* head_{nullptr};
// current free page
std::unique_ptr<Page> curr_page_;
// pointer to the current free page position.
size_t page_ptr_{0};
// allocated pages.
std::vector<std::unique_ptr<Page> > allocated_;
};
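// Sketch of how the fixed-size pool is typically used (the Node3 type is
// hypothetical; the pool hands back raw storage, so placement-new is needed):
//
// \code
//   struct Node3 { double x, y, z; };
//   dmlc::MemoryPool<sizeof(Node3), alignof(Node3)> pool;
//   void* raw = pool.allocate();            // O(1): free list or current page
//   Node3* n = new (raw) Node3{1.0, 2.0, 3.0};
//   n->~Node3();
//   pool.deallocate(n);                     // pushed back onto the free list
// \endcode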
/*!
 * \brief A thread-local allocator that gets memory from a thread-local memory pool.
 *     This is suitable for allocating objects that do not cross threads.
* \tparam T the type of the data to be allocated.
*/
template<typename T>
class ThreadlocalAllocator {
public:
/*! \brief pointer type */
typedef T* pointer;
/*! \brief const pointer type */
typedef const T* const_ptr;
/*! \brief value type */
typedef T value_type;
/*! \brief default constructor */
ThreadlocalAllocator() {}
/*!
* \brief constructor from another allocator
* \param other another allocator
* \tparam U another type
*/
template<typename U>
ThreadlocalAllocator(const ThreadlocalAllocator<U>& other) {}
/*!
* \brief allocate memory
* \param n number of blocks
* \return an uninitialized memory of type T.
*/
inline T* allocate(size_t n) {
CHECK_EQ(n, 1);
typedef ThreadLocalStore<MemoryPool<sizeof(T), alignof(T)> > Store;
return static_cast<T*>(Store::Get()->allocate());
}
/*!
* \brief deallocate memory
* \param p a memory to be returned.
* \param n number of blocks
*/
inline void deallocate(T* p, size_t n) {
CHECK_EQ(n, 1);
typedef ThreadLocalStore<MemoryPool<sizeof(T), alignof(T)> > Store;
Store::Get()->deallocate(p);
}
};
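// Illustrative sketch only: the allocator is meant for single-object
// allocations (n == 1) that never cross threads, e.g. as the backing
// allocator of ThreadlocalSharedPtr below.
//
// \code
//   dmlc::ThreadlocalAllocator<int> alloc;
//   int* p = alloc.allocate(1);    // backed by a thread-local MemoryPool
//   *p = 42;
//   alloc.deallocate(p, 1);
// \endcode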
/*!
 * \brief a shared-pointer-like type that allocates objects
 *  from a thread-local object pool. This object is not thread-safe
 *  but can be faster than shared_ptr in certain use cases.
* \tparam T the data type.
*/
template<typename T>
struct ThreadlocalSharedPtr {
public:
/*! \brief default constructor */
ThreadlocalSharedPtr() : block_(nullptr) {}
/*!
* \brief constructor from nullptr
* \param other the nullptr type
*/
ThreadlocalSharedPtr(std::nullptr_t other) : block_(nullptr) {} // NOLINT(*)
/*!
* \brief copy constructor
* \param other another pointer.
*/
ThreadlocalSharedPtr(const ThreadlocalSharedPtr<T>& other)
: block_(other.block_) {
IncRef(block_);
}
/*!
* \brief move constructor
* \param other another pointer.
*/
ThreadlocalSharedPtr(ThreadlocalSharedPtr<T>&& other)
: block_(other.block_) {
other.block_ = nullptr;
}
/*!
* \brief destructor
*/
~ThreadlocalSharedPtr() {
DecRef(block_);
}
/*!
* \brief move assignment
* \param other another object to be assigned.
* \return self.
*/
inline ThreadlocalSharedPtr<T>& operator=(ThreadlocalSharedPtr<T>&& other) {
DecRef(block_);
block_ = other.block_;
other.block_ = nullptr;
return *this;
}
/*!
* \brief copy assignment
* \param other another object to be assigned.
* \return self.
*/
inline ThreadlocalSharedPtr<T> &operator=(const ThreadlocalSharedPtr<T>& other) {
DecRef(block_);
block_ = other.block_;
IncRef(block_);
return *this;
}
/*! \brief check if nullptr */
inline bool operator==(std::nullptr_t other) const {
return block_ == nullptr;
}
/*!
* \return get the pointer content.
*/
inline T* get() const {
if (block_ == nullptr) return nullptr;
return reinterpret_cast<T*>(&(block_->data));
}
/*!
* \brief reset the pointer to nullptr.
*/
inline void reset() {
DecRef(block_);
block_ = nullptr;
}
/*! \return if use_count == 1*/
inline bool unique() const {
if (block_ == nullptr) return false;
return block_->use_count_ == 1;
}
/*! \return dereference pointer */
inline T* operator*() const {
return reinterpret_cast<T*>(&(block_->data));
}
/*! \return dereference pointer */
inline T* operator->() const {
return reinterpret_cast<T*>(&(block_->data));
}
/*!
* \brief create a new space from threadlocal storage and return it.
* \tparam Args the arguments.
* \param args The input argument
* \return the allocated pointer.
*/
template <typename... Args>
inline static ThreadlocalSharedPtr<T> Create(Args&&... args) {
ThreadlocalAllocator<RefBlock> arena;
ThreadlocalSharedPtr<T> p;
p.block_ = arena.allocate(1);
p.block_->use_count_ = 1;
new (&(p.block_->data)) T(std::forward<Args>(args)...);
return p;
}
private:
// internal reference block
struct RefBlock {
typename std::aligned_storage<sizeof(T), alignof(T)>::type data;
unsigned use_count_;
};
// decrease ref counter
inline static void DecRef(RefBlock* block) {
if (block != nullptr) {
if (--block->use_count_ == 0) {
ThreadlocalAllocator<RefBlock> arena;
T* dptr = reinterpret_cast<T*>(&(block->data));
dptr->~T();
arena.deallocate(block, 1);
}
}
}
// increase ref counter
inline static void IncRef(RefBlock* block) {
if (block != nullptr) {
++block->use_count_;
}
}
// internal block
RefBlock *block_;
};
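// A minimal sketch, assuming single-threaded use of the resulting pointer
// (the reference count is not atomic); the Blob type is made up here:
//
// \code
//   struct Blob { int id; explicit Blob(int i) : id(i) {} };
//   auto p = dmlc::ThreadlocalSharedPtr<Blob>::Create(7);
//   auto q = p;                  // bumps the non-atomic use count
//   assert(q->id == 7 && !p.unique());
//   q.reset();                   // p still owns the object
// \endcode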
} // namespace dmlc
#endif // DMLC_MEMORY_H_
//===== EXPANDED: ../dmlc-core/include/dmlc/memory.h =====
//===== EXPANDING: ../dmlc-core/include/dmlc/array_view.h =====
/*!
* Copyright (c) 2016 by Contributors
* \file array_view.h
* \brief Read only data structure to reference array
*/
#ifndef DMLC_ARRAY_VIEW_H_
#define DMLC_ARRAY_VIEW_H_
namespace dmlc {
/*!
 * \brief Read-only data structure to reference a contiguous memory region of an array.
 *  Provides a unified view over std::vector, std::array and C-style arrays.
 *  This data structure does not guarantee aliveness of the referenced array.
 *
 *  Make sure not to use array_view to record data in async function closures.
 *  Also do not use array_view to create a reference to a temporary data structure.
*
* \tparam ValueType The value
*
* \code
* std::vector<int> myvec{1,2,3};
* dmlc::array_view<int> view(myvec);
* // indexed visit to the view.
* LOG(INFO) << view[0];
*
* for (int v : view) {
* // visit each element in the view
* }
* \endcode
*/
template<typename ValueType>
class array_view {
public:
/*! \brief default constructor */
array_view() = default;
/*!
* \brief default copy constructor
* \param other another array view.
*/
array_view(const array_view<ValueType> &other) = default; // NOLINT(*)
/*!
* \brief default move constructor
* \param other another array view.
*/
array_view(array_view<ValueType>&& other) = default; // NOLINT(*)
/*!
* \brief default assign constructor
* \param other another array view.
* \return self.
*/
array_view<ValueType>& operator=(const array_view<ValueType>& other) = default; // NOLINT(*)
/*!
* \brief construct array view std::vector
* \param other vector container
*/
array_view(const std::vector<ValueType>& other) { // NOLINT(*)
if (other.size() != 0) {
begin_ = &other[0]; size_ = other.size();
}
}
/*!
* \brief construct array std::array
* \param other another array view.
*/
template<std::size_t size>
array_view(const std::array<ValueType, size>& other) { // NOLINT(*)
if (size != 0) {
begin_ = &other[0]; size_ = size;
}
}
/*!
* \brief construct array view from continuous segment
   * \param begin beginning pointer
* \param end end pointer
*/
array_view(const ValueType* begin, const ValueType* end) {
if (begin < end) {
begin_ = begin;
size_ = end - begin;
}
}
/*! \return size of the array */
inline size_t size() const {
return size_;
}
/*! \return begin of the array */
inline const ValueType* begin() const {
return begin_;
}
/*! \return end point of the array */
inline const ValueType* end() const {
return begin_ + size_;
}
/*!
* \brief get i-th element from the view
* \param i The index.
* \return const reference to i-th element.
*/
inline const ValueType& operator[](size_t i) const {
return begin_[i];
}
private:
/*! \brief the begin of the view */
const ValueType* begin_{nullptr};
/*! \brief The size of the view */
size_t size_{0};
};
} // namespace dmlc
#endif // DMLC_ARRAY_VIEW_H_
//===== EXPANDED: ../dmlc-core/include/dmlc/array_view.h =====
namespace nnvm {
/*! \brief any type */
using dmlc::any;
/*! \brief array_view type */
using dmlc::array_view;
/*!\brief getter function of any type */
using dmlc::get;
} // namespace nnvm
#endif // NNVM_BASE_H_
//===== EXPANDED: ../nnvm/include/nnvm/base.h =====
namespace nnvm {
// forward declarations
class Node;
struct NodeAttrs;
template<typename ValueType>
class OpMap;
class OpGroup;
class OpRegistryEntry;
using dmlc::ParamFieldInfo;
/*! \brief constant to indicate it takes any length of positional inputs */
static const uint32_t kVarg = std::numeric_limits<uint32_t>::max();
/*!
* \brief Operator structure.
*
* Besides the fields in the structure,
 * arbitrary additional information can be associated with each op.
* See function GetAttr for details.
*
* \code
* // Example usage of Op
*
 *  // registration of operators
* // NOTE that the attr function can register any
* // additional attributes to the operator
* NNVM_REGISTER_OP(add)
* .describe("add two inputs together")
* .set_num_inputs(2)
* .set_attr<OpKernel>("OpKernel<gpu>", AddKernel)
* .include("ElementwiseOpAttr");
*
* // can register attribute by group
* // all the ops that include the group get the attribute.
* NNVM_REGISTER_OP_GROUP(ElementwiseOpAttr)
* .set_attr<FInferShape>("FInferShape", ElementwiseInferShape);
*
* NNVM_REGISTER_OP(sub)
 *   .describe("subtract one tensor from another")
* .set_num_inputs(2);
*
 *  // Can call register multiple times in different files
* // to register different part of information
* NNVM_REGISTER_OP(sub)
 *  .set_attr<OpKernel>("OpKernel<gpu>", SubKernel)
 *  .include("ElementwiseOpAttr");
*
* // get operators from registry.
* void my_function() {
* const Op* add = Op::Get("add");
* const Op* sub = Op::Get("sub");
* // query basic information about each operator.
 *    assert(add->name == "add");
 *    assert(add->num_inputs == 2);
*
* // get additional registered information,
* // Assume user registered a OpKernel type attribute as gpu_kernel on each operator.
* const OpMap<OpKernel>& kernel = Op::GetAttr<OpKernel>("OpKernel<gpu>");
* // we can get the kernel functions by using operator as key.
* auto add_kernel = kernel[add];
* auto sub_kernel = kernel[sub];
* // subsequent code can make use of the queried kernel functions.
* }
* \endcode
*/
class Op {
public:
/*! \brief name of the operator */
std::string name;
/*!
* \brief detailed description of the operator
* This can be used to generate docstring automatically for the operator.
*/
std::string description;
  /*! \brief description of inputs and keyword arguments */
std::vector<ParamFieldInfo> arguments;
/*!
* \brief number of inputs to the operator,
* -1 means it is variable length
   *  When get_num_inputs is present,
* the number will be decided by get_num_inputs instead.
* \sa get_num_inputs
*/
uint32_t num_inputs = 1;
/*!
* \brief number of outputs of the operator
   *  When get_num_outputs is present,
   *  the number of outputs will be decided by
   *  the get_num_outputs function.
* \sa get_num_outputs
*/
uint32_t num_outputs = 1;
/*!
* \brief get number of outputs given information about the node.
* \param attrs The attribute of the node
* \return number of outputs.
*/
std::function<uint32_t(const NodeAttrs& attrs)> get_num_outputs = nullptr;
/*!
* \brief get number of inputs given information about the node.
* \param attrs The attribute of the node
* \return number of inputs
*/
std::function<uint32_t(const NodeAttrs& attrs)> get_num_inputs = nullptr;
/*!
* \brief Attribute parser to parse the NodeAttrs information.
*
* This can help to get quick access to a parsed attribute
* object
*
* \code
* // Example usage of attr_parser.
*
* // Suppose we want to register operator sum.
* // The parameters about sum operator
* struct SumParam {
* int axis;
* };
* // The parser function
* void SumAttrParser(NodeAttrs* attrs) {
* // This will be invoked during node construction.
* SumParam param;
* // parse axis string to integer
* param.axis = atoi(attrs->dict["axis"].c_str());
* // set the parsed parameter
* attrs->parsed = std::move(param);
* }
* // The other function that can utilize the parsed result.
* TShape SumInferShape(const NodeAttrs& attrs,
* const std::vector<TShape>& ishapes) {
* // we can use the parsed version of param
 *    // without repeatedly parsing the parameter
* const SumParam& param = nnvm::get<SumParam>(attrs.parsed);
* }
* \endcode
*/
std::function<void(NodeAttrs* attrs)> attr_parser = nullptr;
// function fields.
/*!
* \brief setter function during registration
* Set the description of operator
* \param descr the description string.
* \return reference to self.
*/
inline Op& describe(const std::string& descr); // NOLINT(*)
/*!
* \brief Add argument information to the function.
* \param name Name of the argument.
* \param type Type of the argument.
* \param description Description of the argument.
* \return reference to self.
*/
inline Op& add_argument(const std::string &name,
const std::string &type,
const std::string &description);
/*!
   * \brief Append a list of arguments to the end.
* \param args Additional list of arguments.
* \return reference to self.
*/
inline Op& add_arguments(const std::vector<ParamFieldInfo> &args);
/*!
* \brief Set the num_inputs
* \param n The number of inputs to be set.
* \return reference to self.
*/
inline Op& set_num_inputs(uint32_t n); // NOLINT(*)
/*!
   * \brief Set the get_num_inputs function.
* \param fn The function to be set.
* \return reference to self.
*/
inline Op& set_num_inputs(std::function<uint32_t (const NodeAttrs& attr)> fn); // NOLINT(*)
/*!
* \brief Set the num_outputs
* \param n The number of outputs to be set.
* \return reference to self.
*/
inline Op& set_num_outputs(uint32_t n); // NOLINT(*)
/*!
* \brief Set the get_num_outputs function.
* \param fn The function to be set.
* \return reference to self.
*/
inline Op& set_num_outputs(std::function<uint32_t (const NodeAttrs& attr)> fn); // NOLINT(*)
/*!
* \brief Set the attr_parser function.
   * \param fn The attribute parser function to be set.
* \return reference to self.
*/
inline Op& set_attr_parser(std::function<void (NodeAttrs* attrs)> fn); // NOLINT(*)
/*!
* \brief Register additional attributes to operator.
* \param attr_name The name of the attribute.
* \param value The value to be set.
* \param plevel The priority level of this set,
   *  a higher priority level attribute
   *  will replace a lower priority level attribute.
   *  Must be greater than 0.
*
* Cannot set with same plevel twice in the code.
*
* \tparam ValueType The type of the value to be set.
*/
template<typename ValueType>
inline Op& set_attr(const std::string& attr_name, // NOLINT(*)
const ValueType& value,
int plevel = 10);
/*!
* \brief Add another alias to this operator.
* The same Op can be queried with Op::Get(alias)
* \param alias The alias of the operator.
* \return reference to self.
*/
Op& add_alias(const std::string& alias); // NOLINT(*)
/*!
   * \brief Include all the attributes from a registered op group.
* \param group_name The name of the group.
* \return reference to self.
*
* \sa NNVM_REGISTER_OP_GROUP
*/
Op& include(const std::string& group_name);
/*!
* \brief Get an Op for a given operator name.
* Will raise an error if the op has not been registered.
* \param op_name Name of the operator.
* \return Pointer to a Op, valid throughout program lifetime.
*/
static const Op* Get(const std::string& op_name);
/*!
* \brief Get additional registered attribute about operators.
* If nothing has been registered, an empty OpMap will be returned.
* \param attr_name The name of the attribute.
* \return An OpMap of specified attr_name.
* \tparam ValueType The type of the attribute.
*/
template<typename ValueType>
static const OpMap<ValueType>& GetAttr(const std::string& attr_name);
private:
template<typename ValueType>
friend class OpMap;
friend class OpGroup;
friend class dmlc::Registry<Op>;
// Program internal unique index of operator.
// Used to help index the program.
uint32_t index_{0};
// internal constructor
Op();
// get const reference to certain attribute
static const any* GetAttrMap(const std::string& key);
// update the attribute OpMap
static void UpdateAttrMap(const std::string& key,
std::function<void(any*)> updater);
// add a trigger based on tag matching on certain tag attribute
  // This will apply the trigger to all the ops that
  // include the corresponding group.
  // The trigger will also be applied to all future registrations
  // that call include.
static void AddGroupTrigger(const std::string& group_name,
std::function<void(Op*)> trigger);
};
/*!
* \brief A map data structure that takes Op* as key
* and returns ValueType
* \tparam ValueType The type of the value stored in map.
*/
template<typename ValueType>
class OpMap {
public:
/*!
* \brief get the corresponding value element at op
* \param op The key to the map
* \return the const reference to the content value.
*/
inline const ValueType& operator[](const Op* op) const;
/*!
* \brief get the corresponding value element at op with default value.
* \param op The key to the map
* \param def_value The default value when the key does not exist.
* \return the const reference to the content value.
*/
inline const ValueType& get(const Op* op, const ValueType& def_value) const;
/*!
* \brief Check if the map has op as key.
* \param op The key to the map
* \return 1 if op is contained in map, 0 otherwise.
*/
inline int count(const Op* op) const;
private:
friend class Op;
// internal attribute name
std::string attr_name_;
// internal data
std::vector<std::pair<ValueType, int> > data_;
OpMap() = default;
};
/*!
* \brief auxiliary data structure used to
* set attributes to a group of operators
*/
class OpGroup {
public:
/*! \brief the tag key to be matched */
std::string group_name;
/*!
* \brief Register additional attributes to operator group.
* \param attr_name The name of the attribute.
* \param value The value to be set.
* \param plevel The priority level of this set,
   *  a higher priority level attribute
   *  will replace a lower priority level attribute.
   *  Must be greater than 0.
*
* Cannot set with same plevel twice in the code.
*
* \tparam ValueType The type of the value to be set.
*/
template<typename ValueType>
inline OpGroup& set_attr(const std::string& attr_name, // NOLINT(*)
const ValueType& value,
int plevel = 1);
};
// internal macros used by the registration macros below
#define NNVM_REGISTER_VAR_DEF(OpName) \
static DMLC_ATTRIBUTE_UNUSED ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName
#define NNVM_REGISTER_GVAR_DEF(TagName) \
static DMLC_ATTRIBUTE_UNUSED ::nnvm::OpGroup __make_ ## NnvmOpGroup ## _ ## TagName
/*!
* \def NNVM_REGISTER_OP
* \brief Register a new operator, or set attribute of the corresponding op.
*
* \param OpName The name of registry
*
* \code
*
* NNVM_REGISTER_OP(add)
* .describe("add two inputs together")
* .set_num_inputs(2)
* .set_attr<OpKernel>("gpu_kernel", AddKernel);
*
* \endcode
*/
#define NNVM_REGISTER_OP(OpName) \
DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \
::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(#OpName)
/*!
* \def NNVM_REGISTER_OP_GROUP
* \brief Register attribute to a group of operators.
* These attributes will be registered to Op that include the group.
*
* \param GroupName The name of the group.
*
* \code
*
* NNVM_REGISTER_OP(add)
* .include("ElementwiseOpAttr");
*
* // register same attributes to all the ops that include the group
* NNVM_REGISTER_OP_GROUP(ElementwiseOpAttr)
* .set_attr<FInferShape>("FInferShape", ElementwiseInferShape);
*
* NNVM_REGISTER_OP(mul)
* .include("ElementwiseOpAttr");
*
* \endcode
*/
#define NNVM_REGISTER_OP_GROUP(GroupName) \
DMLC_STR_CONCAT(NNVM_REGISTER_GVAR_DEF(GroupName), __COUNTER__) = \
::nnvm::OpGroup {#GroupName}
// implementations of template functions after this.
// member function of Op
template<typename ValueType>
inline const OpMap<ValueType>& Op::GetAttr(const std::string& key) {
const any* ref = GetAttrMap(key);
if (ref == nullptr) {
// update the attribute map of the key by creating new empty OpMap
UpdateAttrMap(key, [key](any* pmap) {
// use callback so it is in lockscope
if (pmap->empty()) {
OpMap<ValueType> pm;
pm.attr_name_ = key;
*pmap = std::move(pm);
}
});
ref = GetAttrMap(key);
}
return nnvm::get<OpMap<ValueType> >(*ref);
}
template<typename ValueType>
inline Op& Op::set_attr( // NOLINT(*)
const std::string& attr_name,
const ValueType& value,
int plevel) {
CHECK_GT(plevel, 0)
<< "plevel in set_attr must be greater than 0";
// update the attribute map of the key by creating new empty if needed.
UpdateAttrMap(attr_name,
[this, attr_name, value, plevel](any* pmap) {
// the callback is in lockscope so is threadsafe.
if (pmap->empty()) {
OpMap<ValueType> pm;
pm.attr_name_ = attr_name;
*pmap = std::move(pm);
}
CHECK(pmap->type() == typeid(OpMap<ValueType>))
<< "Attribute " << attr_name
<< " of operator " << this->name
<< " is registered as inconsistent types"
<< " previously " << pmap->type().name()
<< " current " << typeid(OpMap<ValueType>).name();
std::vector<std::pair<ValueType, int> >& vec =
nnvm::get<OpMap<ValueType> >(*pmap).data_;
      // resize the value vector if needed.
if (vec.size() <= index_) {
vec.resize(index_ + 1,
std::make_pair(ValueType(), 0));
}
std::pair<ValueType, int>& p = vec[index_];
CHECK(p.second != plevel)
<< "Attribute " << attr_name
<< " of operator " << this->name
<< " is already registered with same plevel=" << plevel;
if (p.second < plevel) {
vec[index_] = std::make_pair(value, plevel);
}
});
return *this;
}
inline Op& Op::describe(const std::string& descr) { // NOLINT(*)
this->description = descr;
return *this;
}
inline Op& Op::add_argument(const std::string &name,
const std::string &type,
const std::string &description) {
arguments.push_back({name, type, type, description});
return *this;
}
inline Op& Op::add_arguments(const std::vector<ParamFieldInfo> &args) {
this->arguments.insert(arguments.end(), args.begin(), args.end());
return *this;
}
inline Op& Op::set_num_inputs(uint32_t n) { // NOLINT(*)
this->num_inputs = n;
return *this;
}
inline Op& Op::set_num_inputs(std::function<uint32_t (const NodeAttrs& attr)> fn) { // NOLINT(*)
this->get_num_inputs = fn;
return *this;
}
inline Op& Op::set_num_outputs(uint32_t n) { // NOLINT(*)
this->num_outputs = n;
return *this;
}
inline Op& Op::set_num_outputs(std::function<uint32_t (const NodeAttrs& attr)> fn) { // NOLINT(*)
this->get_num_outputs = fn;
return *this;
}
inline Op& Op::set_attr_parser(std::function<void (NodeAttrs* attrs)> fn) { // NOLINT(*)
this->attr_parser = fn;
return *this;
}
// member functions of OpMap
template<typename ValueType>
inline int OpMap<ValueType>::count(const Op* op) const {
if (op == nullptr) return 0;
const uint32_t idx = op->index_;
return idx < data_.size() ? (data_[idx].second != 0) : 0;
}
template<typename ValueType>
inline const ValueType& OpMap<ValueType>::operator[](const Op* op) const {
CHECK(op != nullptr);
const uint32_t idx = op->index_;
CHECK(idx < data_.size() && data_[idx].second)
<< "Attribute " << attr_name_
<< " has not been registered for Operator " << op->name;
return data_[idx].first;
}
template<typename ValueType>
inline const ValueType& OpMap<ValueType>::get(const Op* op, const ValueType& def_value) const {
if (op == nullptr) return def_value;
const uint32_t idx = op->index_;
if (idx < data_.size() && data_[idx].second) {
return data_[idx].first;
} else {
return def_value;
}
}
template<typename ValueType>
inline OpGroup& OpGroup::set_attr(const std::string& attr_name,
const ValueType& value,
int plevel) {
auto trigger = [attr_name, value, plevel](Op* op) {
op->set_attr<ValueType>(attr_name, value, plevel);
};
Op::AddGroupTrigger(group_name, trigger);
return *this;
}
} // namespace nnvm
#endif // NNVM_OP_H_
//===== EXPANDED: ../nnvm/include/nnvm/op.h =====
//===== EXPANDING: ../nnvm/include/nnvm/tuple.h =====
/*!
* Copyright (c) 2016 by Contributors
* \file tuple.h
* \brief Data structure Tuple and TShape to store dynamic sized shapes.
*/
#ifndef NNVM_TUPLE_H_
#define NNVM_TUPLE_H_
namespace nnvm {
/*! \brief data type to store array index */
typedef uint32_t index_t;
/*!
 * \brief A dynamically sized array data structure that is optimized for storing
 *        a small number of elements of the same type.
*
* Data will be stored in stack when number of elements is small.
* It is suitable to hold shape of Tensor.
*
* \tparam ValueType The type of data stored inside tuple.
* \sa TShape
*/
template<typename ValueType>
class Tuple {
public:
// Tuple requires the content to be simple data type.
static_assert(std::is_pod<ValueType>::value,
"Tuple only support simple data type like int");
/*! \brief default constructor */
Tuple() = default;
/*! \brief destructor */
inline ~Tuple() {
delete [] data_heap_;
}
/*!
* \brief copy constructor from another tuple
* \param s the source tuple
*/
inline Tuple(const Tuple<ValueType>& s) {
this->assign(s.begin(), s.end());
}
/*!
* \brief constructor from initializer list
* \param init the initializer_list
*/
inline Tuple(std::initializer_list<ValueType> init) {
this->assign(init.begin(), init.end());
}
/*!
* \brief move constructor from Tuple
* \param src the source shape
*/
inline Tuple(Tuple<ValueType>&& src) { // NOLINT(*)
this->swap(src);
}
/*!
* \brief construct the Tuple from content of iterator
* \param begin the beginning of iterator
* \param end end the end of the iterator
* \tparam RandomAccessIterator iterator type
*/
template<typename RandomAccessIterator>
inline Tuple(RandomAccessIterator begin,
RandomAccessIterator end) {
this->assign(begin, end);
}
/*!
* \brief Assign content to tuple from iterator.
   * \param begin the beginning of the iterator
* \param end end the end of the iterator
* \tparam RandomAccessIterator iterator type
*/
template<typename RandomAccessIterator>
inline void assign(RandomAccessIterator begin,
RandomAccessIterator end) {
this->SetDim(end - begin);
std::copy(begin, end, this->begin());
}
/*!
* \brief Swap current object with other
* \param other another object to be swapped.
*/
inline void swap(Tuple<ValueType>& other) { // NOLINT(*)
std::swap(ndim_, other.ndim_);
std::swap(num_heap_allocated_, other.num_heap_allocated_);
std::swap(data_stack_, other.data_stack_);
std::swap(data_heap_, other.data_heap_);
}
/*!
* \brief assignment from another tuple.
* \param src source tuple
* \return reference of self
*/
inline Tuple<ValueType>& operator=(const Tuple<ValueType>& src) {
this->assign(src.begin(), src.end());
return *this;
}
/*!
* \brief assignment from rvalue of another tuple.
* \param src source tuple
* \return reference of self
*/
inline Tuple<ValueType>& operator=(Tuple<ValueType>&& src) {
Tuple<ValueType>(std::move(src)).swap(*this);
return *this;
}
/*!
* \brief assignment from initializer list
* \param init the source initializer list
* \return reference of self
*/
inline Tuple<ValueType> &operator=(std::initializer_list<ValueType> init) {
this->assign(init.begin(), init.end());
return *this;
}
/*!
   * \return whether the two tuples are equal
* \param s the tuple to compare against
*/
inline bool operator==(const Tuple<ValueType> &s) const {
if (ndim_ != s.ndim_) return false;
return std::equal(begin(), end(), s.begin());
}
/*!
   * \return whether the two tuples are not equal
* \param s the tuple to compare against
*/
inline bool operator!=(const Tuple<ValueType> &s) const {
return !(*this == s);
}
/*! \return the begin data pointer to content of the tuple */
inline const ValueType *begin() const {
return ndim_ <= kStackCache ? data_stack_ : data_heap_;
}
/*! \return the begin data pointer to content of the tuple */
inline ValueType *begin() {
return ndim_ <= kStackCache ? data_stack_ : data_heap_;
}
/*! \return the data pointer to end of the tuple */
inline const ValueType* end() const {
return ndim_ <= kStackCache ? (data_stack_ + ndim_): (data_heap_ + ndim_);
}
/*! \return the data pointer to end the tuple */
inline ValueType* end() {
return ndim_ <= kStackCache ? (data_stack_ + ndim_): (data_heap_ + ndim_);
}
/*! \return number of dimension of the tuple */
inline index_t ndim() const {
return ndim_;
}
/*!
* \brief get corresponding index
* \param i dimension index
* \return the corresponding dimension size
*/
inline ValueType& operator[](index_t i) {
return begin()[i];
}
/*!
* \brief get corresponding index
* \param i dimension index
* \return the corresponding dimension size
*/
inline const ValueType& operator[](index_t i) const {
return begin()[i];
}
/*!
* \brief Save Tuple to JSON.
* \param writer JSONWriter
*/
inline void Save(dmlc::JSONWriter* writer) const {
std::vector<ValueType> tmp(begin(), end());
writer->Write(tmp);
}
/*!
* \brief Load Tuple from JSON.
* \param reader JSONReader
*/
inline void Load(dmlc::JSONReader* reader) {
std::vector<ValueType> tmp;
reader->Read(&tmp);
this->assign(tmp.begin(), tmp.end());
}
/*!
* \brief allow output string of tuple to ostream
* \param os the output stream
* \param t the tuple
* \return the ostream
*/
friend std::ostream &operator<<(std::ostream &os, const Tuple<ValueType> &t) {
os << '(';
const ValueType* begin = t.begin();
const ValueType* end = t.end();
for (const ValueType* it = begin; it != end; ++it) {
if (it != begin) os << ',';
os << *it;
}
// python style tuple
if (t.ndim() == 1) os << ',';
os << ')';
return os;
}
/*!
* \brief read tuple from the istream
* \param is the input stream
* \param t The tuple
* \return the istream
*/
friend std::istream &operator>>(std::istream &is, Tuple<ValueType> &t) {
// get (
while (true) {
char ch = is.peek();
if (isdigit(ch)) {
ValueType idx;
if (is >> idx) {
t.assign(&idx, &idx + 1);
}
return is;
}
is.get();
if (ch == '(' || ch == '[') break;
if (!isspace(ch)) {
is.setstate(std::ios::failbit);
return is;
}
}
// Handle empty tuple
while (isspace(is.peek())) {
is.get();
}
if (is.peek() == ')') {
is.get();
return is;
}
// Handle non-empty tuple
ValueType idx;
std::vector<ValueType> tmp;
while (is >> idx) {
tmp.push_back(idx);
char ch;
do {
ch = is.get();
} while (isspace(ch));
if (std::is_integral<ValueType>::value && ch == 'L') {
ch = is.get();
}
if (ch == ',') {
while (true) {
ch = is.peek();
if (isspace(ch)) {
is.get(); continue;
}
if (ch == ')' || ch == ']') {
is.get(); break;
}
break;
}
if (ch == ')' || ch == ']') break;
} else if (ch == ')' || ch == ']') {
break;
} else {
is.setstate(std::ios::failbit);
return is;
}
}
t.assign(tmp.begin(), tmp.end());
return is;
}
/*!
* \brief save the content into binary stream
* \param strm the output stream
* \tparam TStream any stream type that have write
*/
template<typename TStream>
inline void Save(TStream *strm) const {
strm->Write(&ndim_, sizeof(ndim_));
strm->Write(begin(), sizeof(ValueType) * ndim_);
}
/*!
* \brief load the content from binary stream
* \param strm the output stream
* \tparam TStream any stream type that have write
* \return whether the load is successful
*/
template<typename TStream>
inline bool Load(TStream *strm) {
if (strm->Read(&ndim_, sizeof(ndim_)) != sizeof(ndim_)) return false;
this->SetDim(ndim_);
size_t nread = sizeof(ValueType) * ndim_;
if (strm->Read(begin(), nread) != nread) return false;
return true;
}
protected:
// stack cache size
static const uint32_t kStackCache = 4;
/*! \brief number of dimension of the tuple */
index_t ndim_{0};
/*! \brief number of cells allocated in data_heap_ */
index_t num_heap_allocated_{0};
/*! \brief in stack space used to store shape when it is small */
ValueType data_stack_[kStackCache];
/*! \brief space to store shape when dimension is big*/
ValueType* data_heap_{nullptr};
// internal function to change the dimension
inline void SetDim(index_t dim) {
if (dim > kStackCache &&
dim > num_heap_allocated_) {
delete [] data_heap_;
data_heap_ = new ValueType[dim];
num_heap_allocated_ = dim;
}
ndim_ = dim;
}
};
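/*
 * Quick usage sketch (the values are arbitrary): small tuples live entirely
 * on the stack, and the stream operators use the python-style "(1,2,3)" form.
 *
 * \code
 *   nnvm::Tuple<int> t{2, 3, 4};
 *   std::ostringstream os;
 *   os << t;                     // prints "(2,3,4)"
 *   nnvm::Tuple<int> u;
 *   std::istringstream is("(5, 6)");
 *   is >> u;                     // u becomes (5,6)
 * \endcode
 */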
/*!
 * \brief A Shape class that is used to represent the shape of each tensor.
*/
class TShape : public Tuple<index_t> {
public:
/*! \brief default constructor */
TShape() = default;
/*!
   * \brief construct a shape with all dimensions set to 1.
   * \param ndim the number of dimensions
*/
inline TShape(index_t ndim) { // NOLINT(*)
this->SetDim(ndim);
std::fill_n(begin(), ndim, 1);
}
/*!
* \brief copy constructor of TShape
* \param s source shape.
*/
inline TShape(const Tuple<index_t>& s) { // NOLINT(*)
this->assign(s.begin(), s.end());
}
/*!
* \brief constructor from initializer list
* \param init the initializer_list
*/
inline TShape(std::initializer_list<index_t> init) {
this->assign(init.begin(), init.end());
}
/*!
* \brief move constructor.
* \param s source shape.
*/
inline TShape(Tuple<index_t>&& s) { // NOLINT(*)
this->swap(s);
}
/*!
* \brief construct the Tuple from content of iterator
* \param begin the beginning of iterator
* \param end end the end of the iterator
* \tparam RandomAccessIterator iterator type
*/
template<typename RandomAccessIterator>
inline TShape(RandomAccessIterator begin,
RandomAccessIterator end) {
this->assign(begin, end);
}
/*!
* \brief assignment function from tshape
* \param src source shape.
* \return self.
*/
inline TShape& operator=(const Tuple<index_t>& src) {
this->assign(src.begin(), src.end());
return *this;
}
/*!
* \brief move assignment function from tshape
* \param src source shape.
* \return self.
*/
inline TShape& operator=(Tuple<index_t>&& src) { // NOLINT(*)
TShape(std::move(src)).swap(*this); // NOLINT(*)
return *this;
}
/*! \return total number of elements in the shape */
inline size_t Size() const {
size_t size = 1;
const index_t* start = begin(), *fin = end();
for (const index_t* it = start; it != fin; ++it) {
size *= *it;
}
return size;
}
/*!
   * \return the product of the shape dimensions in [dimstart, dimend)
* \param dimstart start dimension
* \param dimend end dimension
*/
inline index_t ProdShape(int dimstart, int dimend) const {
index_t num = 1;
const index_t *d = this->data();
for (int i = dimstart; i < dimend; ++i) {
num *= d[i];
}
return num;
}
/*! \return the begin data pointer to content of the tuple */
inline const index_t *data() const {
return begin();
}
/*! \return the begin data pointer to content of the tuple */
inline index_t *data() {
return begin();
}
#ifdef MSHADOW_XINLINE
template<int dim>
inline TShape(const mshadow::Shape<dim> &s) {// NOLINT(*)
this->assign(s.shape_, s.shape_ + dim);
}
template<int dim>
inline TShape(mshadow::Shape<dim> &&s) {// NOLINT(*)
this->assign(s.shape_, s.shape_ + dim);
}
/*!
* \brief assignment from shape
* \param shape source shape
* \tparam dim shape dimension
* \return reference of self
*/
template<int dim>
inline TShape &operator=(const mshadow::Shape<dim> &shape) {
this->assign(shape.shape_, shape.shape_ + dim);
return *this;
}
/*!
* \brief get the shape of tensor specifying dim
* \return the shape requested
* \tparam dim dimension of the tensor
*/
template<int dim>
inline mshadow::Shape<dim> get() const {
CHECK_EQ(dim, ndim())
<< "dimension do not match target dimension " << dim << " vs " << ndim();
const index_t *d = this->data();
mshadow::Shape<dim> s;
for (int i = 0; i < dim; ++i) {
s[i] = d[i];
}
return s;
}
/*!
   * \brief flatten the higher dimensions into the second dimension and return a 2D shape
* \return the flat 2d shape
*/
inline mshadow::Shape<2> FlatTo2D(void) const {
mshadow::Shape<2> s;
if (ndim() == 0) return mshadow::Shape2(0, 0);
const index_t *d = this->data();
s.shape_[1] = d[ndim() - 1];
index_t ymax = 1;
for (index_t i = 1; i < ndim(); ++i) {
ymax *= d[i - 1];
}
s.shape_[0] = ymax;
return s;
}
/*!
* flatten the shape into three parts: [0, axis_begin), [axis_begin, axis_end], (axis_end, ndim)
* \param axis_begin The beginning axis specified.
* \param axis_end The ending axis specified.
* \return the flat 3d shape
*/
inline mshadow::Shape<3> FlatTo3D(index_t axis_begin, index_t axis_end) const {
CHECK(axis_end >= axis_begin);
mshadow::Shape<3> s;
if (ndim() == 0) return mshadow::Shape3(0, 0, 0);
const index_t *d = this->data();
s.shape_[0] = 1;
s.shape_[1] = 1;
s.shape_[2] = 1;
for (index_t i = 0; i < axis_begin; ++i) {
s.shape_[0] *= d[i];
}
for (index_t i = axis_begin; i <= axis_end; ++i) {
s.shape_[1] *= d[i];
}
for (index_t i = axis_end + 1; i < ndim(); ++i) {
s.shape_[2] *= d[i];
}
return s;
}
/*!
   * \brief flatten the axes before and after the specified axis, so it becomes a 3D tensor
* \param axis The axis specified.
* \return the flat 3d shape
*/
inline mshadow::Shape<3> FlatTo3D(index_t axis) const {
return FlatTo3D(axis, axis);
}
inline bool operator==(const TShape &s) const {
if (ndim() != s.ndim()) return false;
return std::equal(begin(), end(), s.begin());
}
inline bool operator!=(const TShape &s) const {
return !(*this == s);
}
/*!
   * \return whether the two shapes are equal
* \param s the shape to compare against
* \tparam dim dimension of the shape
*/
template<int dim>
inline bool operator==(const mshadow::Shape<dim> &s) const {
if (ndim_ != dim) return false;
const index_t *d = dim <= kStackCache ? data_stack_ : data_heap_;
for (index_t i = 0; i < dim; ++i) {
if (d[i] != s.shape_[i]) return false;
}
return true;
}
/*!
   * \return whether the two shapes are not equal
* \param s the shape to compare against
* \tparam dim dimension of the shape
*/
template<int dim>
inline bool operator!=(const mshadow::Shape<dim> &s) const {
return !(*this == s);
}
#endif
};
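/*
 * Sketch of the most common TShape queries (the shape values are arbitrary):
 *
 * \code
 *   nnvm::TShape s{2, 3, 4};
 *   size_t total = s.Size();                   // 2 * 3 * 4 = 24
 *   nnvm::index_t inner = s.ProdShape(1, 3);   // 3 * 4 = 12
 *   mshadow::Shape<3> ms = s.get<3>();         // valid only because s.ndim() == 3
 * \endcode
 */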
} // namespace nnvm
#endif // NNVM_TUPLE_H_
//===== EXPANDED: ../nnvm/include/nnvm/tuple.h =====
//===== EXPANDING: ../nnvm/include/nnvm/symbolic.h =====
/*!
* Copyright (c) 2016 by Contributors
* \file symbolic.h
* \brief Symbolic graph construction API
*
* This API is optional, but useful to allow user
* to construct NNVM Graph easily, and quickly create
* front-end host languages.
*/
#ifndef NNVM_SYMBOLIC_H_
#define NNVM_SYMBOLIC_H_
//===== EXPANDING: ../nnvm/include/nnvm/node.h =====
/*!
* Copyright (c) 2016 by Contributors
* \file node.h
* \brief Graph node data structure.
*/
#ifndef NNVM_NODE_H_
#define NNVM_NODE_H_
namespace nnvm {
// Forward declare node.
class Node;
/*!
 * \brief we always use NodePtr as the reference pointer type
 *  to the node, so this alias can be changed if needed.
*
* By default, NodePtr is a std::shared_ptr of node
*/
using NodePtr = std::shared_ptr<Node>;
/*! \brief an entry that represents output data from a node */
struct NodeEntry {
/*! \brief the source node of this data */
NodePtr node;
/*! \brief index of output from the source. */
uint32_t index;
/*!
* \brief version of input Variable.
* This field can only be nonzero when this->node is a Variable node.
   * version is increased by one each time a Variable gets composed into a mutation Op.
   * This information can help decide the order of operations when a sequence of mutations happens.
*/
uint32_t version;
};
/*!
* \brief The attributes of the current operation node.
 *  These are usually additional parameters such as axis.
*/
struct NodeAttrs {
/*!
* \brief The operator this node uses.
   *  For a placeholder variable, op == nullptr.
*/
const Op *op{nullptr};
/*! \brief name of the node */
std::string name;
/*! \brief Vector representation of positional attributes */
std::vector<double> scalars;
/*! \brief The dictionary representation of attributes */
std::unordered_map<std::string, std::string> dict;
/*!
* \brief A parsed version of attributes,
* This is generated if OpProperty.attr_parser is registered.
* The object can be used to quickly access attributes.
*/
any parsed;
};
/*!
* \brief Node represents an operation in a computation graph.
*/
class Node {
public:
/*! \brief The attributes in the node. */
NodeAttrs attrs;
/*! \brief inputs to this node */
std::vector<NodeEntry> inputs;
/*!
* \brief Optional control flow dependencies
   *  Gives operations that must be performed before this operation.
*/
std::vector<NodePtr> control_deps;
/*! \brief destructor of node */
~Node();
/*! \return operator in this node */
inline const Op* op() const;
/*!
* \brief return whether node is placeholder variable.
* This is equivalent to op == nullptr
* \return whether node is placeholder input variable
*/
inline bool is_variable() const;
/*! \return number of outputs from this node */
inline uint32_t num_outputs() const;
/*! \return number of inputs from this node */
inline uint32_t num_inputs() const;
/*!
* \brief create a new empty shared_ptr of Node.
* \return a created empty node.
*/
static NodePtr Create();
};
// implementation of functions.
inline const Op* Node::op() const {
return this->attrs.op;
}
inline bool Node::is_variable() const {
return this->op() == nullptr;
}
inline uint32_t Node::num_outputs() const {
if (is_variable()) return 1;
if (this->op()->get_num_outputs == nullptr) {
return this->op()->num_outputs;
} else {
return this->op()->get_num_outputs(this->attrs);
}
}
inline uint32_t Node::num_inputs() const {
if (is_variable()) return 1;
if (this->op()->get_num_inputs == nullptr) {
return this->op()->num_inputs;
} else {
return this->op()->get_num_inputs(this->attrs);
}
}
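/*
 * Small sketch of constructing a placeholder node by hand (normally done via
 * Symbol::CreateVariable rather than directly); the name "data" is arbitrary:
 *
 * \code
 *   nnvm::NodePtr var = nnvm::Node::Create();
 *   var->attrs.name = "data";         // op stays nullptr => variable node
 *   assert(var->is_variable());
 *   assert(var->num_outputs() == 1);
 * \endcode
 */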
} // namespace nnvm
#endif // NNVM_NODE_H_
//===== EXPANDED: ../nnvm/include/nnvm/node.h =====
namespace nnvm {
/*!
 * \brief Symbol is a helper class used to represent the operator node in a Graph.
*
* Symbol acts as an interface for building graphs from different components
* like Variable, Functor and Group. Symbol is also exported to python front-end
* (while Graph is not) to enable quick test and deployment. Conceptually,
 * a symbol is the final operation of a graph and thus includes all the information
* required (the graph) to evaluate its output value.
*/
class Symbol {
public:
/*! \brief option passed to ListAttr */
enum ListAttrOption {
/*! \brief recursively list all attributes */
kRecursive = 0,
/*! \brief only list attributes in current node */
kShallow = 1
};
/*! \brief option passed to ListInputNames */
enum ListInputOption {
/*! \brief list all the arguments */
kAll = 0,
/*! \brief list only read only arguments */
kReadOnlyArgs = 1,
/*!
* \brief List auxiliary states that can be mutated by the graph.
* This excludes the ReadOnly arguments
*/
kAuxiliaryStates = 2
};
/*! \brief output entries contained in the symbol */
std::vector<NodeEntry> outputs;
/*!
* \brief Copy the symbol.
* \return A deep copy of this symbol.
*/
Symbol Copy() const;
/*!
* \brief Print the symbol info to output stream.
* \param os The output stream to print to.
*/
void Print(std::ostream &os) const; // NOLINT(*)
/*!
* \brief Get the index-th element from the returned tuple.
* \param index Index of multi output.
* \return The symbol corresponds to the indexed element.
*/
Symbol operator[] (size_t index) const;
/*!
* \brief List the input variable nodes.
*
* The order of the returned list is the same as the order of the input list to `operator()`.
*
* \param option The options to list the arguments.
* \return The arguments list of this symbol, they can be either named or unnamed (empty string).
* \sa ListInputOption
*/
std::vector<NodePtr> ListInputs(ListInputOption option) const;
/*!
* \brief List the input names.
*
* The order of the returned list is the same as the order of the input list to `operator()`.
*
* \param option The options to list the arguments.
* \return The arguments list of this symbol, they can be either named or unnamed (empty string).
* \sa ListInputOption
*/
std::vector<std::string> ListInputNames(ListInputOption option) const;
/*!
* \brief List the names of outputs for this symbol.
*
* For normal operators, it is usually symbol node name + "_output".
*
* \return get the descriptions of outputs for this symbol.
*/
std::vector<std::string> ListOutputNames() const;
/*!
   * \brief Compose the symbol with arguments; this changes the current symbol.
   *
   *  The kwargs passed in can be incomplete; the remaining arguments
   *  will keep their original names.
*
* \param args Positional arguments.
* \param kwargs Keyword arguments for the symbol.
* \param name Name of returned symbol.
*/
void Compose(const array_view<const Symbol*>& args,
const std::unordered_map<std::string, const Symbol*>& kwargs,
const std::string& name);
/*!
* \brief Apply the symbol as a function, compose with arguments
*
* This is equivalent to Copy then Compose.
*
* \param args Positional arguments for the symbol.
* \param kwargs Keyword arguments for the symbol.
* \param name Name of returned symbol.
* \return A new Symbol which is the composition of current symbol with its arguments.
*/
Symbol operator () (const array_view<const Symbol*>& args,
const std::unordered_map<std::string, const Symbol*>& kwargs,
const std::string& name) const;
/*!
   * \brief Add control flow dependencies to the operators in the symbol.
*
* For grouped symbol, an error will be raised. This mutates current symbolic Node.
*
* \param src The symbols to depend on.
*/
void AddControlDeps(const Symbol& src);
  /*!
* \brief Get all the internal nodes of the symbol.
* \return symbol A new symbol whose output contains all the outputs of the symbols
* including input variables and intermediate outputs.
*/
Symbol GetInternals() const;
/*!
* \brief Set additional attributes to current node.
*
* This only works for symbol with outputs from single operators.
* For grouped symbol, an error will be raised.
*
* This function mutates the node's symbol and is not recommended.
*
* \param attrs The attributes to set.
*/
void SetAttrs(const std::vector<std::pair<std::string, std::string> >& attrs);
/*!
* \brief Get attributes from the symbol.
*
* This only works for symbol with outputs from single operators.
* For grouped symbol, an error will be raised.
*
   * \param key Key of the attribute. When key == "name", it returns the name attribute.
* \param out The output value of the attribute.
* \return true If the attribute exists, false if the attribute does not exist.
*/
bool GetAttr(const std::string& key, std::string* out) const;
/*!
* \brief Get attribute dictionary from the symbol.
*
* For grouped symbol, an error will be raised.
*
* \param option If recursive flag is set, the attributes of all children are retrieved.
   *  The name of the symbol will be prepended to each key.
   * \return The created attribute dictionary.
*/
std::unordered_map<std::string, std::string> ListAttrs(ListAttrOption option) const;
/*!
* \brief Get attribute dictionary from the symbol and all children.
*
* For grouped symbol, an error will be raised.
*
   * \return The created attributes in the format <operator_name, key, value>.
*/
std::vector<std::tuple<std::string, std::string, std::string> >
ListAttrsRecursive() const;
/*!
* \brief Create symbolic functor(AtomicSymbol) by given operator and attributes.
* \param op The operator.
* \param attrs The additional attributes.
* \return Symbol that can be used to call compose further.
*/
static Symbol CreateFunctor(const Op* op,
std::unordered_map<std::string, std::string> attrs);
/*!
* \brief Create symbol node representing variable.
* \param name Name of the variable.
* \return The symbol.
*/
static Symbol CreateVariable(const std::string& name);
/*!
   * \brief Create an equivalent symbol by grouping the symbols together.
* \param symbols A list of symbols to be grouped.
* \return The grouped symbol.
*/
static Symbol CreateGroup(const std::vector<Symbol>& symbols);
};
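/*
 * Hedged sketch of composing symbols. "fully_connected" and "num_hidden" are
 * stand-ins for whatever operator name and parameter the front-end actually
 * registers; they are not defined in this file.
 *
 * \code
 *   nnvm::Symbol data = nnvm::Symbol::CreateVariable("data");
 *   nnvm::Symbol fc = nnvm::Symbol::CreateFunctor(
 *       nnvm::Op::Get("fully_connected"), {{"num_hidden", "128"}});
 *   std::vector<const nnvm::Symbol*> args{&data};
 *   fc.Compose(args, {}, "fc1");          // positional compose, no keyword arguments
 *   std::vector<std::string> in = fc.ListInputNames(nnvm::Symbol::kAll);
 * \endcode
 */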
} // namespace nnvm
#endif // NNVM_SYMBOLIC_H_
//===== EXPANDED: ../nnvm/include/nnvm/symbolic.h =====
/*!
*\brief whether to use opencv support
*/
#ifndef MXNET_USE_OPENCV
#define MXNET_USE_OPENCV 1
#endif
/*!
*\brief whether to use cuda support
*/
#ifndef MXNET_USE_CUDA
#define MXNET_USE_CUDA MSHADOW_USE_CUDA
#endif
/*!
*\brief whether to use cudnn library for convolution
*/
#ifndef MXNET_USE_CUDNN
#define MXNET_USE_CUDNN MSHADOW_USE_CUDNN
#endif
/*! \brief Error message for using gpu when MXNET_USE_CUDA==0 */
#define MXNET_GPU_NOT_ENABLED_ERROR "GPU is not enabled"
/*!
* \brief define compatible keywords in g++
 * Used to support g++-4.6 and g++-4.7
*/
#if DMLC_USE_CXX11 && defined(__GNUC__) && !defined(__clang_version__)
#if __GNUC__ == 4 && __GNUC_MINOR__ < 8
#error "Currently we need g++ 4.8 or higher to fully support c++11 features"
#define override
#define final
#endif
#endif
/*!
* \brief define dllexport for Visual Studio
*/
#ifdef _MSC_VER
#ifdef MXNET_EXPORTS
#define MXNET_API __declspec(dllexport)
#else
#define MXNET_API __declspec(dllimport)
#endif
#else
#define MXNET_API
#endif
/*!
* \brief define prediction only
*/
#ifndef MXNET_PREDICT_ONLY
#define MXNET_PREDICT_ONLY 0
#endif
/*!
* \brief define operator message for profiler
*/
#if MXNET_USE_PROFILER
#define PROFILER_MESSAGE(msg) msg
#else
#define PROFILER_MESSAGE(msg) nullptr
#endif
/*! \brief major version */
#define MXNET_MAJOR 0
/*! \brief minor version */
#define MXNET_MINOR 9
/*! \brief patch version */
#define MXNET_PATCH 3
/*! \brief mxnet version */
#define MXNET_VERSION (MXNET_MAJOR*10000 + MXNET_MINOR*100 + MXNET_PATCH)
/*! \brief helper for making version number */
#define MXNET_MAKE_VERSION(major, minor, patch) ((major)*10000 + (minor)*100 + patch)
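// For example, with the values above MXNET_VERSION expands to
// 0*10000 + 9*100 + 3 = 903, and MXNET_MAKE_VERSION(0, 9, 3) yields the same
// number, so a compile-time version check can be written as:
//   #if MXNET_VERSION >= MXNET_MAKE_VERSION(0, 9, 3)
//   ...
//   #endif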
/*!
* \brief define function name as profiler message
*/
#define PROFILER_MESSAGE_FUNCNAME PROFILER_MESSAGE(__FUNCTION__)
/*! \brief namespace of mxnet */
namespace mxnet {
/*! \brief mxnet cpu */
typedef mshadow::cpu cpu;
/*! \brief mxnet gpu */
typedef mshadow::gpu gpu;
/*! \brief index type usually use unsigned */
typedef mshadow::index_t index_t;
/*! \brief data type that will be used to store ndarray */
typedef mshadow::default_real_t real_t;
/*! \brief Shape data structure used to record shape information */
using TShape = nnvm::TShape;
/*! \brief operator structure from NNVM */
using Op = nnvm::Op;
/*! \brief Context information about the execution environment */
struct Context {
/*! \brief Type of device */
enum DeviceType {
kCPU = cpu::kDevMask,
kGPU = gpu::kDevMask,
kCPUPinned = 3
};
/*! \brief the device type we run the op on */
DeviceType dev_type;
/*! \brief device id we are going to run it on */
int32_t dev_id;
/*! \brief default constructor */
Context() : dev_type(kCPU), dev_id(0) {}
/*!
* \brief Get corresponding device mask
* \return cpu::kDevMask or gpu::kDevMask
*/
inline int dev_mask() const {
if (dev_type == kCPUPinned) return cpu::kDevMask;
return dev_type;
}
/*!
* \brief Comparator, used to enable Context as std::map key.
* \param b another context to compare
* \return compared result
*/
inline bool operator<(const Context &b) const;
/*!
* \brief check if current context equals another one
* \param b another context to compare
* \return whether dev mask and id are same
*/
inline bool operator==(const Context &b) const {
return dev_type == b.dev_type && dev_id == b.dev_id;
}
/*!
* \brief check if current context not equals another one
* \param b another context to compare
* \return whether they are not the same
*/
inline bool operator!=(const Context &b) const {
return !(*this == b);
}
/*!
* \brief save the content into binary stream
* \param strm the output stream
*/
inline void Save(dmlc::Stream *strm) const {
strm->Write(&dev_type, sizeof(dev_type));
strm->Write(&dev_id, sizeof(dev_id));
}
/*!
* \brief load the content from binary stream
   * \param strm the input stream
* \return whether the load is successful
*/
inline bool Load(dmlc::Stream *strm) {
if (strm->Read(&dev_type, sizeof(dev_type)) != sizeof(dev_type)) return false;
if (strm->Read(&dev_id, sizeof(int32_t)) != sizeof(int32_t)) return false;
return true;
}
/*! \brief the maximal device type */
static const int32_t kMaxDevType = 4;
/*! \brief the maximal device index */
static const int32_t kMaxDevID = 16;
/*!
* \brief Create a new context.
* \param dev_type device type.
* \param dev_id device id. -1 for current device.
*/
inline static Context Create(DeviceType dev_type, int32_t dev_id = -1);
/*! \return CPU Context */
inline static Context CPU(int32_t dev_id = 0);
/*!
* Create a GPU context.
   * \param dev_id the device id, -1 for the current GPU.
   * \return GPU Context.
*/
inline static Context GPU(int32_t dev_id = -1);
/*!
* Create a pinned CPU context.
   * \param dev_id the device id of the corresponding GPU, -1 for the current GPU.
   * \return Pinned CPU context.
*/
inline static Context CPUPinned(int32_t dev_id = -1);
/*!
   * Create a context from a string of the format [cpu|gpu|cpu_pinned](n)
* \param str the string pattern
* \return Context
*/
inline static Context FromString(std::string str);
};
/*!
* \brief execution time context.
* The information needed in runtime for actual execution.
*/
struct RunContext {
/*!
* \brief the stream of the device, can be NULL or Stream<gpu>* in GPU mode
*/
void *stream;
/*!
* \brief get mshadow stream from Context
* \return the mshadow stream
* \tparam xpu the device type of the stream
*/
template<typename xpu>
inline mshadow::Stream<xpu>* get_stream() const {
return static_cast<mshadow::Stream<xpu>*>(stream);
}
};
} // namespace mxnet
//! \cond Doxygen_Suppress
namespace mxnet {
// implementing Context
inline bool Context::operator<(const Context &b) const {
if (dev_type == b.dev_type) {
return dev_id < b.dev_id;
} else {
return dev_type < b.dev_type;
}
}
inline Context Context::Create(DeviceType dev_type, int32_t dev_id) {
Context ctx;
ctx.dev_type = dev_type;
if (dev_id < 0) {
ctx.dev_id = 0;
#if MXNET_USE_CUDA
if (dev_type != kCPU) {
CHECK_EQ(cudaGetDevice(&ctx.dev_id), cudaSuccess);
}
#endif
} else {
ctx.dev_id = dev_id;
}
return ctx;
}
inline Context Context::CPU(int32_t dev_id) {
return Create(kCPU, dev_id);
}
inline Context Context::CPUPinned(int32_t dev_id) {
return Create(kCPUPinned, dev_id);
}
inline Context Context::GPU(int32_t dev_id) {
return Create(kGPU, dev_id);
}
inline Context Context::FromString(std::string str) {
Context ret;
try {
std::string::size_type l = str.find('(');
CHECK_NE(l, std::string::npos);
std::string::size_type r = str.find(')');
CHECK_EQ(r, str.length()-1);
std::string type = str.substr(0, l);
int id = std::stoi(str.substr(l+1, r-l-1));
if (type == "cpu") {
ret = CPU(id);
} else if (type == "gpu") {
ret = GPU(id);
} else if (type == "cpu_pinned") {
ret = CPUPinned(id);
} else {
LOG(FATAL) << "Invalid context string " << str;
}
} catch (...) {
LOG(FATAL) << "Invalid context string " << str;
}
return ret;
}
inline std::ostream& operator<<(std::ostream &out, const Context &ctx) {
if (ctx.dev_type == Context::kCPU) {
out << "cpu(";
} else if (ctx.dev_type == Context::kGPU) {
out << "gpu(";
} else if (ctx.dev_type == Context::kCPUPinned) {
out << "cpu_pinned(";
} else {
out << "unknown(";
}
out << ctx.dev_id << ")";
return out;
}
} // namespace mxnet
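/*!
 * A small usage sketch for Context (device ids are illustrative; a GPU
 * context is only meaningful in a CUDA-enabled build):
 * \code
 * mxnet::Context cpu_ctx = mxnet::Context::CPU();          // cpu(0)
 * mxnet::Context gpu_ctx = mxnet::Context::GPU(1);         // gpu(1)
 * // round-trip through the string form accepted by FromString
 * mxnet::Context parsed = mxnet::Context::FromString("gpu(1)");
 * CHECK(parsed == gpu_ctx);
 * // Contexts order by device type first, then device id, so they can be
 * // used directly as std::map keys.
 * std::map<mxnet::Context, int> ops_per_device;
 * ++ops_per_device[cpu_ctx];
 * \endcode
 */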
//===== EXPANDING: ../include/mxnet/tensor_blob.h =====
/*!
* Copyright (c) 2014 by Contributors
* \file tensor_blob.h
* \brief TBlob class that holds common representation of
 *  arbitrary-dimension tensors and can be transformed
 *  into a normal fixed-dimension tensor
* \author Tianqi Chen
*/
#ifndef MXNET_TENSOR_BLOB_H_
#define MXNET_TENSOR_BLOB_H_
#if MXNET_USE_MKL2017 == 1
#endif
namespace mxnet {
/*!
* \brief tensor blob class that can be used to hold tensor of any dimension,
 *  any device and any data type.
 *  This is a weak type that can be used to transfer data through an interface.
 *  TBlob itself does not involve any arithmetic operations,
 *  but it can be converted to a tensor of fixed dimension for further operations.
 *
 *  Like a tensor, this data structure behaves like a pointer class and does not
 *  implicitly allocate or de-allocate space.
 *  This data structure can be helpful to hold tensors of different dimensions
 *  while they wait for further processing.
*/
class TBlob {
public:
/*! \brief pointer to the data */
void *dptr_;
/*! \brief shape of the tensor */
TShape shape_;
/*!
* \brief storing the stride information in x dimension
*/
index_t stride_;
/*! \brief device mask of the corresponding device */
int dev_mask_;
/*! \brief type flag of the tensor blob */
int type_flag_;
  /*! \brief storage for the MKL chunk buffer blob, used for experimental purposes only */
#if MKL_EXPERIMENTAL == 1
std::shared_ptr<MKLMemHolder> Mkl_mem_;
#endif
/*! \brief default constructor, default copy assign will work */
TBlob(void)
: dptr_(NULL), dev_mask_(cpu::kDevMask),
type_flag_(mshadow::DataType<real_t>::kFlag) {
#if MKL_EXPERIMENTAL == 1
Mkl_mem_ = NULL;
#endif
}
/*!
* \brief constructor that construct TBlob from contiguous memory
* \param dptr the pointer to the memory
* \param shape the shape of the data
* \param dev_mask the device mask, can be cpu::kDevMask or gpu::kDevMask
*/
template<typename DType>
TBlob(DType *dptr,
const TShape &shape,
int dev_mask)
: dptr_(dptr), shape_(shape),
stride_(shape[shape.ndim() - 1]),
dev_mask_(dev_mask),
type_flag_(mshadow::DataType<DType>::kFlag) {
#if MKL_EXPERIMENTAL == 1
Mkl_mem_ = NULL;
#endif
}
/*!
* \brief constructor that construct TBlob from contiguous memory
* \param dptr the pointer to the memory
* \param shape the shape of the data
* \param dev_mask the device mask, can be cpu::kDevMask or gpu::kDevMask
* \param type_flag the type flag. Can be one of enum mshadow::dtype
*/
TBlob(void *dptr,
const TShape &shape,
int dev_mask,
int type_flag)
: dptr_(dptr), shape_(shape),
stride_(shape[shape.ndim() - 1]),
dev_mask_(dev_mask),
type_flag_(type_flag) {
#if MKL_EXPERIMENTAL == 1
Mkl_mem_ = NULL;
#endif
}
/*!
* \brief constructor from tensor
* \param src source tensor
* \tparam Device which device the tensor is on
* \tparam dim tensor dimension
* \tparam DType the type of elements in the tensor
*/
template<typename Device, int dim, typename DType>
TBlob(const mshadow::Tensor<Device, dim, DType> &src) { // NOLINT(*)
*this = src;
#if MKL_EXPERIMENTAL == 1
Mkl_mem_ = NULL;
#endif
}
/*!
* \brief assignment from tensor
* \param src source tensor
* \tparam Device which device the tensor is on
* \tparam dim tensor dimension
* \tparam DType the type of elements in the tensor
* \return reference of self
*/
template<typename Device, int dim, typename DType>
inline TBlob
&operator=(const mshadow::Tensor<Device, dim, DType> &src) {
dptr_ = src.dptr_;
shape_ = src.shape_;
stride_ = src.stride_;
dev_mask_ = Device::kDevMask;
type_flag_ = mshadow::DataType<DType>::kFlag;
return *this;
}
/*!
   * \return whether the tensor's memory is contiguous
*/
inline bool CheckContiguous(void) const {
return shape_[shape_.ndim() - 1] == stride_;
}
/*!
   * \brief flatten the tensor to 2 dimensions, collapsing the higher dimensions together
* \param stream the possible stream target tensor should reside on
* \tparam Device which device the tensor is on
* \tparam DType the type of elements in the tensor
* \return tensor after flatten
*/
template<typename Device, typename DType>
inline mshadow::Tensor<Device, 2, DType> FlatTo2D(
mshadow::Stream<Device> *stream = NULL) const {
    CHECK(Device::kDevMask == dev_mask_)
      << "TBlob.FlatTo2D: device type does not match specified type";
    CHECK(mshadow::DataType<DType>::kFlag == type_flag_)
      << "TBlob.FlatTo2D: data type does not match specified type. "
      << "Expected: " << type_flag_ << " v.s. given " << mshadow::DataType<DType>::kFlag;
#if MKL_EXPERIMENTAL == 1
if (Mkl_mem_ != nullptr) {
Mkl_mem_->check_and_prv_to_cpu(dptr_);
}
#endif
return mshadow::Tensor<Device, 2, DType>(static_cast<DType*>(dptr_),
shape_.FlatTo2D(), stride_, stream);
}
/*!
* \brief flatten the tensor to 1 dimension, collapse all the dimensions together.
* \param stream the possible stream target tensor should reside on
* \tparam Device which device the tensor is on
* \tparam DType the type of elements in the tensor
* \return tensor after flatten
*/
template<typename Device, typename DType>
inline mshadow::Tensor<Device, 1, DType> FlatTo1D(
mshadow::Stream<Device> *stream = NULL) const {
return this->get_with_shape<Device, 1, DType>(
mshadow::Shape1(shape_.Size()), stream);
}
  /*! \brief return the number of dimensions of the tensor inside */
inline int ndim(void) const {
return shape_.ndim();
}
/*!
* \brief return size of i-th dimension, start counting from highest dimension
   * \param idx the dimension, counted from the highest dimension
* \return the size
*/
inline index_t size(index_t idx) const {
return shape_[idx];
}
/*! \brief total number of elements in the tensor */
inline index_t Size(void) const {
return shape_.Size();
}
/*! \brief get pointer in dtype */
template<typename DType>
inline DType* dptr() const {
CHECK(mshadow::DataType<DType>::kFlag == type_flag_)
<< "TBlob.dptr(): data type do not match specified type.";
return static_cast<DType*>(dptr_);
}
/*!
   * \brief fetch the tensor, with respect to a specific dimension;
   * if dim does not match the stored dimension, an error will be issued
* \return the tensor requested
* \param stream the possible stream target tensor should reside on
* \tparam Device which device the tensor is on
* \tparam dim dimension of the tensor
* \tparam DType the type of elements in the tensor
*/
template<typename Device, int dim, typename DType>
inline mshadow::Tensor<Device, dim, DType> get(mshadow::Stream<Device> *stream = NULL) const {
CHECK(Device::kDevMask == dev_mask_)
<< "TBlob.get: device type do not match specified type";
CHECK(mshadow::DataType<DType>::kFlag == type_flag_)
<< "TBlob.get_with_shape: data type do not match specified type."
<< "Expected: " << type_flag_ << " v.s. given " << mshadow::DataType<DType>::kFlag;
#if MKL_EXPERIMENTAL == 1
if (Mkl_mem_ != nullptr) {
Mkl_mem_->check_and_prv_to_cpu(dptr_);
}
#endif
return mshadow::Tensor<Device, dim, DType>(static_cast<DType*>(dptr_),
shape_.get<dim>(),
stride_, stream);
}
/*!
* \brief fetch a tensor in given shape
   *  If the size does not match the stored size, an error will be issued
* \return the tensor requested
* \param shape the shape required
* \param stream the possible stream target tensor should reside on
* \tparam Device which device the tensor is on
* \tparam dim dimension of the tensor
* \tparam DType the type of elements in the tensor
*/
template<typename Device, int dim, typename DType>
inline mshadow::Tensor<Device, dim, DType> get_with_shape(
const mshadow::Shape<dim> &shape,
mshadow::Stream<Device> *stream = NULL) const {
    CHECK(Device::kDevMask == dev_mask_)
      << "TBlob.get_with_shape: device type does not match specified type";
    CHECK(mshadow::DataType<DType>::kFlag == type_flag_)
      << "TBlob.get_with_shape: data type does not match specified type. "
      << "Expected: " << type_flag_ << " v.s. given " << mshadow::DataType<DType>::kFlag;
    CHECK_EQ(this->CheckContiguous(), true) << "TBlob.get_with_shape: must be contiguous";
CHECK_EQ(this->shape_.Size(), shape.Size())
<< "TBlob.get_with_shape: new and old shape do not match total elements";
#if MKL_EXPERIMENTAL == 1
if (Mkl_mem_ != nullptr) {
Mkl_mem_->check_and_prv_to_cpu(dptr_);
}
#endif
return mshadow::Tensor<Device, dim, DType>(static_cast<DType*>(dptr_),
shape,
shape[dim - 1],
stream);
}
/*!
   * \brief flatten the tensor to 3 dimensions,
   *  collapsing the dimensions before and after the specified axis.
* \param axis The axis specified.
* \param stream the possible stream target tensor should reside on
* \tparam Device which device the tensor is on
* \tparam DType the type of elements in the tensor
* \return tensor after flatten
*/
template<typename Device, typename DType>
inline mshadow::Tensor<Device, 3, DType> FlatTo3D(
int axis, mshadow::Stream<Device> *stream = NULL) const {
return this->get_with_shape<Device, 3, DType>(
this->shape_.FlatTo3D(axis), stream);
}
/*!
   * \brief flatten the tensor to 3 dimensions,
* collapse the dimension: [0, axis_begin), [axis_begin, axis_end], (axis_end, ndim).
* \param axis_begin The beginning axis specified.
* \param axis_end The ending axis specified.
* \param stream the possible stream target tensor should reside on
* \tparam Device which device the tensor is on
* \tparam DType the type of elements in the tensor
* \return tensor after flatten
*/
template<typename Device, typename DType>
inline mshadow::Tensor<Device, 3, DType> FlatTo3D(
int axis_begin, int axis_end,
mshadow::Stream<Device> *stream = NULL) const {
return this->get_with_shape<Device, 3, DType>(
this->shape_.FlatTo3D(axis_begin, axis_end), stream);
}
};
} // namespace mxnet
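/*!
 * A minimal TBlob sketch (the buffer and shape are illustrative; TBlob never
 * owns the memory it points to, it only describes it):
 * \code
 * float data[6] = {0, 1, 2, 3, 4, 5};
 * mxnet::TShape shape{2, 3};
 * mxnet::TBlob blob(data, shape, mshadow::cpu::kDevMask);
 * CHECK_EQ(blob.ndim(), 2);
 * CHECK_EQ(blob.Size(), 6);
 * // view the same memory as a fixed-dimension mshadow tensor on CPU
 * mshadow::Tensor<mshadow::cpu, 2, float> t = blob.get<mshadow::cpu, 2, float>();
 * float last = t[1][2];  // == 5
 * \endcode
 */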
namespace dmlc {
// Add a few patches to support TShape in dmlc/parameter.
DMLC_DECLARE_TYPE_NAME(mxnet::TShape, "Shape(tuple)");
DMLC_DECLARE_TYPE_NAME(nnvm::Tuple<int>, "Shape(tuple)");
namespace parameter {
template<>
class FieldEntry<mxnet::TShape>
: public FieldEntryBase<FieldEntry<mxnet::TShape>, mxnet::TShape> {
public:
FieldEntry() : enforce_nonzero_(false), expect_ndim_(0) {}
// parent class
typedef FieldEntryBase<FieldEntry<mxnet::TShape>, mxnet::TShape> Parent;
virtual void Check(void *head) const {
Parent::Check(head);
mxnet::TShape &v = this->Get(head);
if (expect_ndim_ != 0 && v.ndim() != expect_ndim_) {
std::ostringstream os;
os << "value " << v << "for Parameter " << this->key_
<< " has wrong dimensions, expected dimension=" << expect_ndim_;
throw dmlc::ParamError(os.str());
}
if (enforce_nonzero_) {
for (mxnet::index_t i = 0; i < v.ndim(); ++i) {
if (v[i] == 0U) {
std::ostringstream os;
os << "value " << v << "for Parameter " << this->key_
<< " is invalid, the input shape must be nonzero in all dimensions";
throw dmlc::ParamError(os.str());
}
}
}
}
inline FieldEntry<mxnet::TShape> &enforce_nonzero() {
this->enforce_nonzero_ = true;
return this->self();
}
inline FieldEntry<mxnet::TShape> &set_expect_ndim(mxnet::index_t ndim) {
expect_ndim_ = ndim;
return this->self();
}
private:
// whether all the entries need to be nonzero
bool enforce_nonzero_;
  // expected number of dimensions; default = 0 means no restriction.
mxnet::index_t expect_ndim_;
};
} // namespace parameter
} // namespace dmlc
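/*!
 * The FieldEntry specialization above lets dmlc parameter structs validate
 * TShape fields. A hedged sketch (MyParam and its field are hypothetical; a
 * real parameter also needs DMLC_REGISTER_PARAMETER(MyParam) in a .cc file):
 * \code
 * struct MyParam : public dmlc::Parameter<MyParam> {
 *   mxnet::TShape kernel;
 *   DMLC_DECLARE_PARAMETER(MyParam) {
 *     DMLC_DECLARE_FIELD(kernel)
 *     .set_expect_ndim(2)      // require exactly two dimensions, e.g. (h, w)
 *     .enforce_nonzero()       // reject shapes with any zero entry
 *     .describe("Convolution kernel size: (h, w).");
 *   }
 * };
 * \endcode
 */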
#endif // MXNET_TENSOR_BLOB_H_
//===== EXPANDED: ../include/mxnet/tensor_blob.h =====
//! \endcond
#endif // MXNET_BASE_H_
//===== EXPANDED: ../include/mxnet/base.h =====
//===== EXPANDING: ../nnvm/src/core/graph.cc =====
/*!
* Copyright (c) 2016 by Contributors
 * \file graph.cc
 * \brief Graph and IndexedGraph data structure implementation.
*/
//===== EXPANDING: ../nnvm/include/nnvm/graph.h =====
/*!
* Copyright (c) 2016 by Contributors
* \file graph.h
 * \brief Configuration of nnvm as well as basic data structures.
*/
#ifndef NNVM_GRAPH_H_
#define NNVM_GRAPH_H_
namespace nnvm {
class IndexedGraph;
/*!
* \brief Symbolic computation graph.
* This is the intermediate representation for optimization pass.
*/
class Graph {
public:
/*! \brief outputs of the computation graph. */
std::vector<NodeEntry> outputs;
/*!
* \brief attributes of a graph
   *  Note that attributes are shared pointers and can be shared across graphs.
   *
   *  It is highly recommended to keep each attribute immutable.
   *  It is also safe to implement copy-on-write semantics:
   *
   *  copy when shared_ptr.unique is not true, and reuse the original space
   *  when shared_ptr.unique is true.
*/
std::unordered_map<std::string, std::shared_ptr<any> > attrs;
/*!
* \brief Get the immutable attribute from attrs.
* \param attr_name the name of the attribute
* \return the reference to corresponding attribute
* \tparam T the type of the attribute.
*/
template<typename T>
inline const T& GetAttr(const std::string& attr_name) const;
/*!
* \brief Get a move copy of the attribute, implement copy on write semantics.
* The content is moved if the reference counter of shared_ptr is 1.
* The attribute is erased from attrs after the call.
*
* \param attr_name the name of the attribute
* \return a new copy of the corresponding attribute.
* \tparam T the type of the attribute.
*/
template<typename T>
inline T MoveCopyAttr(const std::string& attr_name);
/*!
   * \brief get an indexed graph of the current graph; if it does not exist, create it on demand
* \return The indexed graph.
* \sa IndexedGraph
*/
const IndexedGraph& indexed_graph();
private:
// internal structure of indexed graph
std::shared_ptr<const IndexedGraph> indexed_graph_;
};
/*!
 * \brief Auxiliary data structure to index a graph.
 *  It maps Nodes in the graph to consecutive integer node_ids.
 *  It also maps IndexedGraph::NodeEntry to consecutive integer entry_ids.
 *  This allows storing properties of Nodes and NodeEntries in
 *  compact vectors and accessing them quickly without resorting to a hash map.
*
* The node_id and entry_rptr are the same as the JSON graph produced by SaveJSON Pass.
*/
class IndexedGraph {
public:
/*! \brief represents a data in the graph */
struct NodeEntry {
/*! \brief the source node id in the computation graph */
uint32_t node_id;
/*! \brief index of output from the source. */
uint32_t index;
/*! \brief version of the node */
uint32_t version;
};
/*! \brief Node data structure in IndexedGraph */
struct Node {
/*! \brief pointer to the source node */
const nnvm::Node* source;
/*! \brief inputs to the node */
array_view<NodeEntry> inputs;
/*! \brief control flow dependencies to the node */
array_view<uint32_t> control_deps;
};
/*! \return number of nodes in the graph */
inline size_t num_nodes() const {
return nodes_.size();
}
/*! \return total number of NodeEntry in the graph */
inline size_t num_node_entries() const {
return entry_rptr_.back();
}
/*!
   * \brief Get a unique entry id between 0 and num_node_entries()
* for a given IndexedGraph::NodeEntry
* \param node_id The node index
* \param index the output index
* \return the unique index.
*/
inline uint32_t entry_id(uint32_t node_id, uint32_t index) const {
return entry_rptr_[node_id] + index;
}
/*!
   * \brief Get a unique entry id between 0 and num_node_entries()
* for a given IndexedGraph::NodeEntry
* \param e The entry to query for index.
* \return the unique index.
*/
inline uint32_t entry_id(const NodeEntry& e) const {
return entry_rptr_[e.node_id] + e.index;
}
/*!
   * \brief Get a unique entry id between 0 and num_node_entries()
* for a given NodeEntry.
* \param e The entry to query for index.
* \return the unique index.
*/
inline uint32_t entry_id(const nnvm::NodeEntry& e) const {
return entry_rptr_[node_id(e.node.get())] + e.index;
}
/*!
* \brief Get the corresponding node id for a given Node in the IndexedGraph.
* \param node The Node to query for index.
* \return the node index.
*/
inline uint32_t node_id(const nnvm::Node* node) const {
return node2index_.at(node);
}
/*!
* \brief Get the corresponding Node structure for a given node_id.
* \param node_id The node id
* \return const reference to the corresponding IndexedGraph::Node
*/
inline const Node& operator[](uint32_t node_id) const {
return nodes_[node_id];
}
/*!
* \brief Get the corresponding Node structure
* \param node The pointer to the Node structure
* \return const reference to the corresponding IndexedGraph::Node
*/
inline const Node& operator[](const nnvm::Node* node) const {
return nodes_[node_id(node)];
}
/*! \return list of argument nodes */
inline const std::vector<uint32_t>& input_nodes() const {
return input_nodes_;
}
/*! \return list of mutable nodes */
inline const std::unordered_set<uint32_t>& mutable_input_nodes() const {
return mutable_input_nodes_;
}
/*! \return list of output entries */
inline const std::vector<NodeEntry>& outputs() const {
return outputs_;
}
  // disallow copy assignment
IndexedGraph(const IndexedGraph&) = delete;
private:
friend class Graph;
/*!
   * \brief Construct an IndexedGraph from a normal Graph
* \param other The source graph.
*/
explicit IndexedGraph(const Graph& other);
// Node pointers in CSR structure.
std::vector<Node> nodes_;
// Index to all input nodes.
std::vector<uint32_t> input_nodes_;
// Index to all mutable input nodes.
std::unordered_set<uint32_t> mutable_input_nodes_;
// space to store the outputs entries
std::vector<NodeEntry> outputs_;
// mapping from node to index.
std::unordered_map<const nnvm::Node*, uint32_t> node2index_;
// CSR pointer of node entries
std::vector<size_t> entry_rptr_;
  // space to store the input entries of each node
std::vector<NodeEntry> input_entries_;
// control flow dependencies
std::vector<uint32_t> control_deps_;
};
/*!
* \brief perform a Post Order DFS visit to each node in the graph.
 *  This order is deterministic and is also topologically sorted.
* \param heads The heads in the graph.
* \param fvisit a function of type std::function<void(const std::shared_ptr<Node>&)>
* \tparam FVisit The function type to perform the visit.
*/
template<typename FVisit>
inline void DFSVisit(const std::vector<NodeEntry>& heads, FVisit fvisit);
// inline function implementations
template<typename T>
inline const T& Graph::GetAttr(const std::string& attr_name) const {
auto it = attrs.find(attr_name);
CHECK(it != attrs.end())
<< "Cannot find attribute " << attr_name << " in the graph";
return nnvm::get<T>(*it->second);
}
template<typename T>
inline T Graph::MoveCopyAttr(const std::string& attr_name) {
auto it = attrs.find(attr_name);
CHECK(it != attrs.end())
<< "Cannot find attribute " << attr_name << " in the graph";
std::shared_ptr<any> sptr = it->second;
attrs.erase(it);
if (sptr.unique()) {
return std::move(nnvm::get<T>(*sptr));
} else {
return nnvm::get<T>(*sptr);
}
}
template <typename GNode, typename HashType,
typename FVisit, typename HashFunc,
typename InDegree, typename GetInput>
void PostOrderDFSVisit(const std::vector<GNode>& heads,
FVisit fvisit,
HashFunc hash,
InDegree indegree,
GetInput getinput) {
std::vector<std::pair<GNode, uint32_t> > stack;
std::unordered_set<HashType> visited;
for (auto& head : heads) {
HashType head_hash = hash(head);
if (visited.count(head_hash) == 0) {
stack.push_back(std::make_pair(head, 0));
visited.insert(head_hash);
}
while (!stack.empty()) {
std::pair<GNode, uint32_t>& back = stack.back();
if (back.second == indegree(back.first)) {
fvisit(back.first);
stack.pop_back();
} else {
const GNode& input = getinput(back.first, back.second++);
HashType input_hash = hash(input);
if (visited.count(input_hash) == 0) {
stack.push_back(std::make_pair(input, 0));
visited.insert(input_hash);
}
}
}
}
}
template<typename FVisit>
inline void DFSVisit(const std::vector<NodeEntry>& heads,
FVisit fvisit) {
typedef const NodePtr* GNode;
std::vector<GNode> head_nodes(heads.size());
std::transform(heads.begin(), heads.end(), head_nodes.begin(),
[](const NodeEntry& e)->GNode {
return &e.node;
});
PostOrderDFSVisit<GNode, Node*>(
head_nodes,
[fvisit](GNode n) { fvisit(*n); }, // FVisit
[](GNode n)->Node* { return n->get(); }, // HashFunc
[](GNode n)->uint32_t { // InDegree
return (*n)->inputs.size() + (*n)->control_deps.size();
},
[](GNode n, uint32_t index)->GNode { // GetInput
if (index < (*n)->inputs.size()) {
return &(*n)->inputs.at(index).node;
} else {
return &(*n)->control_deps.at(index - (*n)->inputs.size());
}
});
}
} // namespace nnvm
#endif // NNVM_GRAPH_H_
//===== EXPANDED: ../nnvm/include/nnvm/graph.h =====
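/*!
 * A short traversal sketch for Graph and IndexedGraph (assumes `sym` is an
 * existing nnvm::Symbol whose outputs are used as the graph outputs):
 * \code
 * nnvm::Graph g;
 * g.outputs = sym.outputs;
 * // indexed_graph() builds (and caches) the CSR-style index on first use
 * const nnvm::IndexedGraph& idx = g.indexed_graph();
 * for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
 *   const nnvm::IndexedGraph::Node& inode = idx[nid];
 *   LOG(INFO) << nid << ": " << inode.source->attrs.name
 *             << " (" << inode.inputs.size() << " inputs)";
 * }
 * // the same nodes can also be visited in topological order directly
 * nnvm::DFSVisit(g.outputs, [](const nnvm::NodePtr& n) {
 *   LOG(INFO) << "visit " << n->attrs.name;
 * });
 * \endcode
 */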
//===== EXPANDING: ../nnvm/include/nnvm/op_attr_types.h =====
/*!
* Copyright (c) 2016 by Contributors
* \file op_attr_types.h
* \brief Data structures that can appear in operator attributes.
*/
#ifndef NNVM_OP_ATTR_TYPES_H_
#define NNVM_OP_ATTR_TYPES_H_
namespace nnvm {
// These types are optional attributes in each operator.
// Each attribute can be required by some passes.
/*!
 * \brief Return the list of input argument names of each operator.
*
* \param attrs The attributes of the node.
* \return list of inputs
* \note Register under "FListInputNames", default return {"data"}.
*
* FListInputNames enables automatic variable creation for missing arguments.
*/
using FListInputNames = std::function<std::vector<std::string> (const NodeAttrs& attrs)>;
/*!
* \brief Return number of visible outputs by the user.
*
* \param attrs The attributes of the node.
*
* \note Register under "FNumVisibleOutputs", default not registered.
* This can be used to hide certain output from the user,
* but the additional outputs can be used to pass information from
* forward to gradient pass.
*/
using FNumVisibleOutputs = std::function<uint32_t (const NodeAttrs& attrs)>;
/*!
 * \brief Return the list of output argument names of each operator.
 *
 * \param attrs The attributes of the node.
 * \return list of outputs
 * \note Register under "FListOutputNames", default return {"outputs"}.
 *
 * FListOutputNames customizes naming for operator outputs.
*/
using FListOutputNames = std::function<std::vector<std::string> (const NodeAttrs& attrs)>;
/*!
 * \brief Get the list of inputs the operator will mutate.
 * \param attrs The attributes of the node.
 * \return list of input indices it mutates.
 *
 * \note Register under "FMutateInputs"; by default no input is mutated.
 * FMutateInputs enables correct handling of mutation order.
*/
using FMutateInputs = std::function<std::vector<uint32_t> (const NodeAttrs& attrs)>;
/*!
* \brief Inference function of certain type.
 * \tparam AttrType The type of the attribute to be inferred.
* \return whether all attributes are inferred.
*/
template<typename AttrType>
using FInferNodeEntryAttr = std::function<bool (const NodeAttrs& attrs,
std::vector<AttrType> *in_attrs,
std::vector<AttrType> *out_attrs)>;
/*!
* \brief Shape inference function.
* Update the shapes given the input shape information.
* TShape.ndim() == 0 means the shape is still unknown.
*
* \note Register under "FInferShape",
* by default do not update any shapes.
*
* FInferShape is needed by shape inference
*/
using FInferShape = FInferNodeEntryAttr<TShape>;
/*!
* \brief Type inference function.
* Update the type given the known type information.
*
* \note Register under "FInferType",
* by default set all the output types to 0.
*/
using FInferType = FInferNodeEntryAttr<int>;
/*!
* \brief Whether this op is an explicit backward operator,
* If TIsBackward is true:
* - The first control_deps of the node points to the corresponding forward operator.
*
* \note Register under "TIsBackward"
* This enables easier shape/type inference for backward operators.
*/
using TIsBackward = bool;
/*!
* \brief Get possible inplace options.
* This function enables optimization to reuse memory of inputs in output.
* \param attrs The attributes of the node
* \param in_data The input data.
* \param out_data The output data.
 * \return list of pairs that map input->output,
* indicating possible in place operations.
*
* \note Register under "FInplaceOption", by default no inplace can happen.
*/
using FInplaceOption = std::function<
std::vector<std::pair<int, int> > (const NodeAttrs& attrs)>;
/*!
 * \brief Get the list of inputs in the op whose contents are actually not used by the operator.
 * These are dummy inputs that can be used, for example, in zeros_like and ones_like.
*
* \param attrs The attributes of the node
* \return list input index that are not used by the operator.
*
* \note Register under "FIgnoreInputs".
*/
using FIgnoreInputs = std::function<
std::vector<uint32_t> (const NodeAttrs& attrs)>;
/*!
* \brief Get the gradient node of the op node
* This function generates the backward graph of the node
* \param nodeptr The node to take gradient
* \param out_grads Gradient of current node's outputs
* \return gradients of the inputs
*
* \note Register under "FGradient"
*/
using FGradient = std::function<std::vector<NodeEntry>(
const NodePtr& nodeptr,
const std::vector<NodeEntry>& out_grads)>;
/*!
* \brief Set the attributes of input variable.
* Usually used for setting initialization or weight decay.
* \param attrs The attributes of this node.
* \param var the input variable
* \param index index of var in all inputs
*/
using FSetInputVarAttrOnCompose = std::function<void(
const NodeAttrs& attrs,
NodePtr var,
const int index)>;
} // namespace nnvm
#endif // NNVM_OP_ATTR_TYPES_H_
//===== EXPANDED: ../nnvm/include/nnvm/op_attr_types.h =====
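/*!
 * Operator attributes such as FInferShape are attached at registration time.
 * A hedged sketch for a hypothetical unary operator "myop" (NNVM_REGISTER_OP
 * and the registry setters come from nnvm/op.h, expanded earlier in this file):
 * \code
 * NNVM_REGISTER_OP(myop)
 * .set_num_inputs(1)
 * .set_num_outputs(1)
 * .set_attr<nnvm::FInferShape>("FInferShape",
 *   [](const nnvm::NodeAttrs& attrs,
 *      std::vector<nnvm::TShape>* in_shapes,
 *      std::vector<nnvm::TShape>* out_shapes) {
 *     // identity shape: the output shape equals the input shape
 *     out_shapes->at(0) = in_shapes->at(0);
 *     return true;
 *   })
 * .set_attr<nnvm::FInplaceOption>("FInplaceOption",
 *   [](const nnvm::NodeAttrs& attrs) {
 *     // output 0 may reuse the memory of input 0
 *     return std::vector<std::pair<int, int> >{{0, 0}};
 *   });
 * \endcode
 */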
namespace nnvm {
const IndexedGraph& Graph::indexed_graph() {
if (indexed_graph_ == nullptr) {
indexed_graph_.reset(new IndexedGraph(*this));
}
return *indexed_graph_;
}
// implement constructor from graph
IndexedGraph::IndexedGraph(const Graph &g) {
entry_rptr_.push_back(0);
std::vector<size_t> inputs_rptr{0}, control_rptr{0};
DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr]
(const NodePtr& n) {
CHECK_LT(nodes_.size(), std::numeric_limits<uint32_t>::max());
uint32_t nid = static_cast<uint32_t>(nodes_.size());
// nodes_
IndexedGraph::Node new_node;
new_node.source = n.get();
nodes_.emplace_back(std::move(new_node));
// arg_nodes_
if (n->is_variable()) {
input_nodes_.push_back(nid);
}
// node2index_
node2index_[n.get()] = nid;
// entry rptr
entry_rptr_.push_back(entry_rptr_.back() + n->num_outputs());
// input entries
for (const auto& e : n->inputs) {
auto it = node2index_.find(e.node.get());
CHECK(it != node2index_.end() && it->first == e.node.get());
input_entries_.emplace_back(NodeEntry{it->second, e.index, e.version});
}
inputs_rptr.push_back(input_entries_.size());
// control deps
for (const auto& nptr : n->control_deps) {
auto it = node2index_.find(nptr.get());
CHECK(it != node2index_.end() && it->first == nptr.get());
control_deps_.push_back(it->second);
}
control_rptr.push_back(control_deps_.size());
});
for (const auto& e : g.outputs) {
outputs_.emplace_back(NodeEntry{
node2index_.at(e.node.get()), e.index, e.version});
}
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
std::unordered_set<uint32_t> mutable_inputs;
// setup array view
// input_entries_ and control_rptr must not change after this step.
const NodeEntry* iptr = dmlc::BeginPtr(input_entries_);
for (size_t nid = 0; nid < nodes_.size(); ++nid) {
nodes_[nid].inputs = array_view<NodeEntry>(
iptr + inputs_rptr[nid], iptr + inputs_rptr[nid + 1]);
if (nodes_[nid].source->op() != nullptr &&
fmutate_inputs.count(nodes_[nid].source->op())) {
for (uint32_t i : fmutate_inputs[nodes_[nid].source->op()](nodes_[nid].source->attrs)) {
mutable_input_nodes_.insert(nodes_[nid].inputs[i].node_id);
}
}
}
const uint32_t* cptr = dmlc::BeginPtr(control_deps_);
for (size_t nid = 0; nid < nodes_.size(); ++nid) {
nodes_[nid].control_deps = array_view<uint32_t>(
cptr + control_rptr[nid], cptr + control_rptr[nid + 1]);
}
}
} // namespace nnvm
//===== EXPANDED: ../nnvm/src/core/graph.cc =====
//===== EXPANDING: ../nnvm/src/core/op.cc =====
/*!
* Copyright (c) 2016 by Contributors
* \file op.cc
* \brief Support for operator registry.
*/
namespace dmlc {
// enable registry
DMLC_REGISTRY_ENABLE(nnvm::Op);
} // namespace dmlc
namespace nnvm {
// single manager of operator information.
struct OpManager {
// mutex to avoid registration from multiple threads.
  // a recursive mutex is needed because a trigger may call UpdateAttrMap
std::recursive_mutex mutex;
// global operator counter
std::atomic<int> op_counter{0};
// storage of additional attribute table.
std::unordered_map<std::string, std::unique_ptr<any> > attr;
// storage of existing triggers
std::unordered_map<std::string, std::vector<std::function<void(Op*)> > > tmap;
// group of each operator.
std::vector<std::unordered_set<std::string> > op_group;
  // get the singleton of the manager
static OpManager* Global() {
static OpManager inst;
return &inst;
}
};
// constructor
Op::Op() {
OpManager* mgr = OpManager::Global();
index_ = mgr->op_counter++;
}
Op& Op::add_alias(const std::string& alias) { // NOLINT(*)
dmlc::Registry<Op>::Get()->AddAlias(this->name, alias);
return *this;
}
// find operator by name
const Op* Op::Get(const std::string& name) {
const Op* op = dmlc::Registry<Op>::Find(name);
CHECK(op != nullptr)
<< "Operator " << name << " is not registered";
return op;
}
// Get attribute map by key
const any* Op::GetAttrMap(const std::string& key) {
auto& dict = OpManager::Global()->attr;
auto it = dict.find(key);
if (it != dict.end()) {
return it->second.get();
} else {
return nullptr;
}
}
// update attribute map
void Op::UpdateAttrMap(const std::string& key,
std::function<void(any*)> updater) {
OpManager* mgr = OpManager::Global();
  // a named guard is required: an unnamed temporary would unlock immediately
  std::lock_guard<std::recursive_mutex> lock(mgr->mutex);
std::unique_ptr<any>& value = mgr->attr[key];
if (value.get() == nullptr) value.reset(new any());
if (updater != nullptr) updater(value.get());
}
void Op::AddGroupTrigger(const std::string& group_name,
std::function<void(Op*)> trigger) {
OpManager* mgr = OpManager::Global();
  std::lock_guard<std::recursive_mutex> lock(mgr->mutex);
auto& tvec = mgr->tmap[group_name];
tvec.push_back(trigger);
auto& op_group = mgr->op_group;
for (const Op* op : dmlc::Registry<Op>::List()) {
if (op->index_ < op_group.size() &&
op_group[op->index_].count(group_name) != 0) {
trigger((Op*)op); // NOLINT(*)
}
}
}
Op& Op::include(const std::string& group_name) {
OpManager* mgr = OpManager::Global();
  std::lock_guard<std::recursive_mutex> lock(mgr->mutex);
auto it = mgr->tmap.find(group_name);
if (it != mgr->tmap.end()) {
for (auto& trigger : it->second) {
trigger(this);
}
}
auto& op_group = mgr->op_group;
if (index_ >= op_group.size()) {
op_group.resize(index_ + 1);
}
op_group[index_].insert(group_name);
return *this;
}
} // namespace nnvm
//===== EXPANDED: ../nnvm/src/core/op.cc =====
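/*!
 * A lookup sketch against the registry implemented above ("__ewise_sum__" is
 * just one operator name referenced elsewhere in this file; any registered
 * name works):
 * \code
 * const nnvm::Op* op = nnvm::Op::Get("__ewise_sum__");
 * // per-operator attribute maps are fetched once and then indexed by Op*
 * static auto& fmutate =
 *     nnvm::Op::GetAttr<nnvm::FMutateInputs>("FMutateInputs");
 * if (fmutate.count(op) != 0) {
 *   // this operator declares that it mutates some of its inputs
 * }
 * \endcode
 */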
//===== EXPANDING: ../nnvm/src/core/symbolic.cc =====
/*!
* Copyright (c) 2016 by Contributors
* \file symbolic.cc
* \brief Symbolic graph composition API.
*/
namespace nnvm {
namespace symbol_constants {
const char *kNamespaceSeparator = "$";
} // namespace symbol_constants
// auxiliary version attribute in variable.
struct VariableParam {
uint32_t version{0};
};
NodePtr CreateVariableNode(const std::string& name) {
NodePtr n = Node::Create();
n->attrs.op = nullptr;
n->attrs.name = name;
n->attrs.parsed = VariableParam();
return n;
}
// scan over a node's inputs, updating the version to the latest
// If the node's op mutates a certain input variable,
// the version of that variable will increase.
// The version is used to implicitly order the mutation sequences.
inline void UpdateNodeVersion(Node *n) {
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
for (NodeEntry& e : n->inputs) {
if (e.node->is_variable()) {
e.version = nnvm::get<VariableParam>(e.node->attrs.parsed).version;
}
}
if (fmutate_inputs.count(n->op()) != 0) {
for (uint32_t i : fmutate_inputs[n->op()](n->attrs)) {
NodeEntry& e = n->inputs[i];
CHECK(e.node->is_variable())
<< "Mutation target can only be Variable";
// increase the version of the variable.
e.version = ++nnvm::get<VariableParam>(e.node->attrs.parsed).version;
}
}
}
inline std::string DefaultVarName(const std::string &op_name,
const std::string &arg_name) {
if (op_name.length() == 0) {
return arg_name;
} else {
return op_name + '_' + arg_name;
}
}
inline void KeywordArgumentMismatch(const char *source,
const std::vector<std::string>& user_args,
const array_view<std::string>& args) {
std::unordered_set<std::string> keys(args.begin(), args.end());
std::ostringstream head, msg;
msg << "\nCandidate arguments:\n";
for (size_t i = 0; i < args.size(); ++i) {
msg << "\t[" << i << ']' << args[i] << '\n';
}
for (const auto& key : user_args) {
if (keys.count(key) == 0) {
      LOG(FATAL) << source
                 << ": Keyword argument name " << key << " not found."
<< msg.str();
}
}
}
template<typename T>
inline std::vector<std::string> GetKeys(
const std::unordered_map<std::string, T>& kwargs) {
std::vector<std::string> keys(kwargs.size());
std::transform(kwargs.begin(), kwargs.end(), keys.begin(),
[](decltype(*kwargs.begin())& kv) { return kv.first; });
return keys;
}
// whether the symbol is atomic functor
inline bool IsAtomic(const std::vector<NodeEntry>& outputs) {
return outputs[0].node->inputs.size() == 0 &&
outputs[0].node->control_deps.size() == 0;
}
// public functions
Symbol Symbol::Copy() const {
std::unordered_map<Node*, NodePtr> old_new;
// use DFSVisit to copy all the nodes
DFSVisit(this->outputs, [&old_new](const NodePtr& node) {
NodePtr np = Node::Create();
np->attrs = node->attrs;
old_new[node.get()] = std::move(np);
});
// connect nodes of new graph
for (const auto &kv : old_new) {
for (const NodeEntry& e : kv.first->inputs) {
Node *ptr = e.node.get();
kv.second->inputs.emplace_back(NodeEntry{old_new[ptr], e.index, e.version});
}
for (const NodePtr& p : kv.first->control_deps) {
kv.second->control_deps.emplace_back(old_new[p.get()]);
}
}
// set the head
Symbol ret;
for (const NodeEntry &e : outputs) {
ret.outputs.emplace_back(NodeEntry{old_new[e.node.get()], e.index, e.version});
}
return ret;
}
void Symbol::Print(std::ostream &os) const {
if (outputs.size() == 1 &&
outputs[0].node->inputs.size() == 0 &&
outputs[0].node->control_deps.size() == 0) {
if (outputs[0].node->is_variable()) {
os << "Variable:" << outputs[0].node->attrs.name << '\n';
} else {
os << "AtomicFunctor "<< " Op:" << outputs[0].node->op()->name << '\n';
}
} else {
// use DFSVisit to copy all the nodes
os << "Symbol Outputs:\n";
for (size_t i = 0; i < outputs.size(); ++i) {
os << "\toutput[" << i << "]=" << outputs[i].node->attrs.name
<< '(' << outputs[i].index << ")\n";
}
DFSVisit(this->outputs, [&os](const NodePtr& node) {
if (node->is_variable()) {
os << "Variable:" << node->attrs.name << '\n';
} else {
os << "--------------------\n";
os << "Op:" << node->op()->name << ", Name=" << node->attrs.name << '\n'
<< "Inputs:\n";
for (size_t i = 0; i < node->inputs.size(); ++i) {
const NodeEntry& e = node->inputs[i];
os << "\targ[" << i << "]=" << e.node->attrs.name
<< '(' << e.index << ")";
if (e.node->is_variable()) {
os << " version=" << e.version << '\n';
} else {
os << '\n';
}
}
if (!node->attrs.dict.empty()) {
os << "Attrs:\n";
// make an ordered copy because unordered_map doesn't guarantee order.
std::map<std::string, std::string> sorted_dict(
node->attrs.dict.begin(), node->attrs.dict.end());
for (auto &kv : sorted_dict) {
os << '\t' << kv.first << '=' << kv.second << '\n';
}
}
if (node->control_deps.size() != 0) {
os << "Control deps:\n";
for (size_t i = 0; i < node->control_deps.size(); ++i) {
os << "\tcdep[" << i << "]=" << node->control_deps[i]->attrs.name << '\n';
}
}
}
});
}
}
Symbol Symbol::operator[] (size_t index) const {
size_t nreturn = outputs.size();
  CHECK_LT(index, nreturn) << "Symbol only accepts an index smaller than the number of outputs";
if (nreturn == 1) {
return *this;
} else {
Symbol s;
s.outputs.push_back(outputs[index]);
return s;
}
}
std::vector<NodePtr> Symbol::ListInputs(ListInputOption option) const {
std::vector<NodePtr> ret;
if (option == kAll) {
DFSVisit(this->outputs, [&ret](const NodePtr &node) {
if (node->is_variable()) {
ret.push_back(node);
}
});
} else {
std::unordered_set<Node*> mutable_set;
std::vector<NodePtr> vlist;
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
DFSVisit(this->outputs, [&ret, &mutable_set, &vlist](const NodePtr &node) {
if (node->is_variable()) {
vlist.push_back(node);
} else if (fmutate_inputs.count(node->op())) {
for (uint32_t i : fmutate_inputs[node->op()](node->attrs)){
mutable_set.insert(node->inputs[i].node.get());
}
}
});
for (const NodePtr& node : vlist) {
if ((option == kReadOnlyArgs && mutable_set.count(node.get()) == 0) ||
(option == kAuxiliaryStates && mutable_set.count(node.get()) != 0)) {
ret.emplace_back(node);
}
}
}
return ret;
}
std::vector<std::string> Symbol::ListInputNames(ListInputOption option) const {
std::vector<NodePtr> inputs = ListInputs(option);
std::vector<std::string> ret(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
ret[i] = inputs[i]->attrs.name;
}
return ret;
}
std::vector<std::string> Symbol::ListOutputNames() const {
  static auto& flist_outputs = Op::GetAttr<FListOutputNames>("FListOutputNames");
std::vector<std::string> ret;
for (auto &head : outputs) {
if (head.node->is_variable()) {
ret.push_back(head.node->attrs.name);
} else {
const std::string& hname = head.node->attrs.name;
std::string rname;
      FListOutputNames fn = flist_outputs.get(head.node->op(), nullptr);
if (fn != nullptr) {
rname = fn(head.node->attrs)[head.index];
} else {
rname = "output";
if (head.node->num_outputs() != 1) {
std::ostringstream os;
os << rname << head.index;
rname = os.str();
}
}
if (hname.length() == 0) {
ret.push_back(std::move(rname));
} else {
ret.push_back(hname + '_' + rname);
}
}
}
return ret;
}
// compositional logic
void Symbol::Compose(const array_view<const Symbol*>& args,
const std::unordered_map<std::string, const Symbol*>& kwargs,
const std::string& name) {
static auto& flist_inputs = Op::GetAttr<FListInputNames>("FListInputNames");
static auto& fset_attrs = Op::GetAttr<FSetInputVarAttrOnCompose>("FSetInputVarAttrOnCompose");
CHECK(!outputs[0].node->is_variable()) << "Variable cannot be composed";
// parameter check.
for (size_t i = 0; i < args.size(); ++i) {
CHECK_EQ(args[i]->outputs.size(), 1U)
<< "Argument " << i << " is a tuple, single value is required";
}
for (const auto& kv : kwargs) {
CHECK_EQ(kv.second->outputs.size(), 1U)
<< "Keyword Argument " << kv.first << " is a tuple, single value is required";
}
// assign new name
outputs[0].node->attrs.name = name;
// Atomic functor composition.
if (IsAtomic(outputs)) {
Node* n = outputs[0].node.get();
uint32_t n_req = n->num_inputs();
if (n_req != kVarg) {
n->inputs.resize(n_req);
CHECK_LE(args.size(), n_req)
<< "Incorrect number of arguments, requires " << n_req
<< ", provided " << args.size();
for (size_t i = 0; i < args.size(); ++i) {
n->inputs[i] = args[i]->outputs[0];
}
// switch to keyword argument matching
if (args.size() != n_req) {
FListInputNames fn = flist_inputs.get(n->op(), nullptr);
auto arg_names = (fn == nullptr) ? std::vector<std::string>{"data"} : fn(n->attrs);
if (arg_names.size() != n_req) {
LOG(FATAL) << "Not enough argument to call operator " << outputs[0].node->op()->name;
}
size_t nmatched = 0;
for (size_t i = args.size(); i < n_req; ++i) {
auto it = kwargs.find(arg_names[i]);
if (it != kwargs.end() && it->first == arg_names[i]) {
n->inputs[i] = it->second->outputs[0];
++nmatched;
} else {
n->inputs[i] = NodeEntry{
CreateVariableNode(DefaultVarName(name, arg_names[i])), 0, 0};
// copy attribute of parent over automatically created variables
n->inputs[i].node->attrs.dict = n->attrs.dict;
}
}
if (nmatched != kwargs.size()) {
n->inputs.clear();
std::vector<std::string> keys = GetKeys(kwargs);
array_view<std::string> view(dmlc::BeginPtr(arg_names) + args.size(),
dmlc::BeginPtr(arg_names) + arg_names.size());
KeywordArgumentMismatch("Symbol.Compose", keys, view);
}
}
} else {
      CHECK_EQ(kwargs.size(), 0U) << "Variable-length functions do not accept kwargs";
n->inputs.reserve(args.size());
for (const Symbol* s : args) {
n->inputs.push_back(s->outputs[0]);
}
}
UpdateNodeVersion(n);
FSetInputVarAttrOnCompose fn = fset_attrs.get(n->op(), nullptr);
if (fn != nullptr) {
for (size_t i = 0; i < n->inputs.size(); ++i) {
if (n->inputs[i].node->is_variable()) {
fn(n->attrs, n->inputs[i].node, i);
}
}
}
} else {
// general composition
CHECK_EQ(args.size(), 0U)
<< "General composition only support kwargs for now";
size_t nmatched = 0;
size_t arg_counter = 0;
std::unordered_map<Node *, const NodeEntry*> replace_map;
    // replace_map stores the existing replacement plan for argument nodes
auto find_replace_map = [&nmatched, &arg_counter, &args, &kwargs, &replace_map]
(const NodePtr &node) {
if (node->is_variable()) {
if (arg_counter < args.size()) {
replace_map[node.get()] = &(args[arg_counter]->outputs[0]);
++arg_counter;
} else {
// match kwargs
auto kit = kwargs.find(node->attrs.name);
if (kit != kwargs.end()) {
replace_map[node.get()] = &(kit->second->outputs[0]);
++nmatched;
}
}
}
};
DFSVisit(this->outputs, find_replace_map);
if (nmatched == kwargs.size() && arg_counter <= args.size()) {
std::vector<Node*> update_nodes;
std::vector<std::pair<NodeEntry*, const NodeEntry*> > replace_plan;
auto find_replace_plan = [&replace_map, &replace_plan, &update_nodes]
(const NodePtr &node) {
        // visit all the children, find possible replacements
bool repl = false;
for (size_t i = 0; i < node->inputs.size(); ++i) {
NodeEntry *e = &(node->inputs[i]);
if (e->node->is_variable()) {
auto iter = replace_map.find(e->node.get());
if (iter != replace_map.end()) {
replace_plan.push_back(std::make_pair(e, iter->second));
repl = true;
}
}
}
if (repl) update_nodes.push_back(node.get());
};
DFSVisit(this->outputs, find_replace_plan);
for (const auto& kv : replace_plan) {
*(kv.first) = *(kv.second);
}
for (Node* n : update_nodes) {
UpdateNodeVersion(n);
}
} else {
std::vector<std::string> keys = GetKeys(kwargs);
std::vector<std::string> arg_names = ListInputNames(kAll);
array_view<std::string> view(dmlc::BeginPtr(arg_names) + arg_counter,
dmlc::BeginPtr(arg_names) + arg_names.size());
KeywordArgumentMismatch("Symbol.Compose", keys, arg_names);
}
}
}
Symbol Symbol::operator () (const array_view<const Symbol*>& args,
const std::unordered_map<std::string, const Symbol*>& kwargs,
const std::string& name) const {
Symbol s = this->Copy();
s.Compose(args, kwargs, name);
return s;
}
void Symbol::AddControlDeps(const Symbol& src) {
CHECK_EQ(outputs.size(), 1U)
<< "AddControlDeps only works for nongrouped symbol";
Node* n = outputs[0].node.get();
for (const NodeEntry& sp : src.outputs) {
n->control_deps.push_back(sp.node);
}
}
Symbol Symbol::GetInternals() const {
static auto& fnum_vis_output = Op::GetAttr<FNumVisibleOutputs>("FNumVisibleOutputs");
Symbol ret;
DFSVisit(this->outputs, [&ret](const NodePtr& node) {
Node* n = node.get();
if (n->is_variable()) {
// grab version from variable.
VariableParam& param = nnvm::get<VariableParam>(n->attrs.parsed);
ret.outputs.emplace_back(NodeEntry{node, 0, param.version});
} else {
uint32_t nout = n->num_outputs();
if (fnum_vis_output.count(n->op())) {
nout = fnum_vis_output[n->op()](n->attrs);
}
for (uint32_t i = 0; i < nout; ++i) {
ret.outputs.emplace_back(NodeEntry{node, i, 0});
}
}
});
return ret;
}
void Symbol::SetAttrs(const std::vector<std::pair<std::string, std::string> >& attrs) {
Node* node = outputs[0].node.get();
for (const NodeEntry& e : outputs) {
CHECK(node == e.node.get())
<< "Symbol.SetAttrs only works for non-grouped symbol";
}
for (const auto& kv : attrs) {
if (kv.first == "name") {
node->attrs.name = kv.second;
} else {
node->attrs.dict[kv.first] = kv.second;
}
}
if (node->op() != nullptr && node->op()->attr_parser != nullptr) {
node->op()->attr_parser(&(node->attrs));
}
}
bool Symbol::GetAttr(const std::string& key, std::string* out) const {
Node* node = outputs[0].node.get();
for (const NodeEntry& e : outputs) {
if (node != e.node.get()) return false;
}
if (key == "name") {
*out = node->attrs.name;
return true;
}
auto it = node->attrs.dict.find(key);
if (it == node->attrs.dict.end()) return false;
*out = it->second;
return true;
}
std::unordered_map<std::string, std::string> Symbol::ListAttrs(ListAttrOption option) const {
if (option == kRecursive) {
std::unordered_map<std::string, std::string> ret;
DFSVisit(this->outputs, [&ret](const NodePtr& n) {
for (const auto& it : n->attrs.dict) {
ret[n->attrs.name + symbol_constants::kNamespaceSeparator + it.first] = it.second;
}
});
return ret;
} else {
return outputs[0].node->attrs.dict;
}
}
std::vector<std::tuple<std::string, std::string, std::string> >
Symbol::ListAttrsRecursive() const {
std::vector<std::tuple<std::string, std::string, std::string> > ret;
DFSVisit(this->outputs, [&ret](const NodePtr& n) {
for (const auto& it : n->attrs.dict) {
ret.emplace_back(std::make_tuple(n->attrs.name, it.first, it.second));
}
});
return ret;
}
Symbol Symbol::CreateFunctor(const Op* op,
std::unordered_map<std::string, std::string> attrs) {
static auto& fnum_vis_output = Op::GetAttr<FNumVisibleOutputs>("FNumVisibleOutputs");
Symbol s;
NodePtr n = Node::Create();
n->attrs.op = op;
n->attrs.dict = std::move(attrs);
if (n->op()->attr_parser != nullptr) {
n->op()->attr_parser(&(n->attrs));
}
uint32_t nout = n->num_outputs();
if (fnum_vis_output.count(n->op())) {
nout = fnum_vis_output[n->op()](n->attrs);
}
for (uint32_t i = 0; i < nout; ++i) {
s.outputs.emplace_back(NodeEntry{n, i, 0});
}
return s;
}
Symbol Symbol::CreateGroup(const std::vector<Symbol> &symbols) {
Symbol ret;
for (const auto &s : symbols) {
ret.outputs.insert(ret.outputs.end(), s.outputs.begin(), s.outputs.end());
}
return ret;
}
Symbol Symbol::CreateVariable(const std::string& name) {
Symbol s;
s.outputs.emplace_back(NodeEntry{CreateVariableNode(name), 0, 0});
return s;
}
} // namespace nnvm
//===== EXPANDED: ../nnvm/src/core/symbolic.cc =====
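/*!
 * An inspection sketch for the member functions implemented above (assumes
 * `net` is a composed Symbol, e.g. the one built in the composition sketch
 * after symbolic.h):
 * \code
 * // names of all inputs, and of the read-only arguments only
 * std::vector<std::string> all_inputs = net.ListInputNames(nnvm::Symbol::kAll);
 * std::vector<std::string> read_only =
 *     net.ListInputNames(nnvm::Symbol::kReadOnlyArgs);
 * // every intermediate output becomes reachable through GetInternals()
 * nnvm::Symbol internals = net.GetInternals();
 * for (const std::string& name : internals.ListOutputNames()) {
 *   LOG(INFO) << name;
 * }
 * \endcode
 */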
//===== EXPANDING: ../nnvm/src/core/node.cc =====
/*!
* Copyright (c) 2016 by Contributors
* \file node.cc
* \brief Graph node data structure.
*/
namespace nnvm {
Node::~Node() {
if (inputs.size() != 0) {
// explicit deletion via DFS
    // this is used to avoid stack overflow caused by a chain of deletions
std::vector<Node*> stack{this};
std::vector<NodePtr> to_delete;
while (!stack.empty()) {
Node* n = stack.back();
stack.pop_back();
for (NodeEntry& e : n->inputs) {
if (e.node.unique()) {
stack.push_back(e.node.get());
to_delete.emplace_back(std::move(e.node));
} else {
e.node.reset();
}
}
for (NodePtr& sp : n->control_deps) {
if (sp.unique()) {
stack.push_back(sp.get());
} else {
sp.reset();
}
}
n->inputs.clear();
}
}
}
NodePtr Node::Create() {
return std::make_shared<Node>();
}
} // namespace nnvm
//===== EXPANDED: ../nnvm/src/core/node.cc =====
//===== EXPANDING: ../nnvm/src/core/pass.cc =====
/*!
* Copyright (c) 2016 by Contributors
* \file pass.cc
* \brief Support for pass registry.
*/
//===== EXPANDING: ../nnvm/include/nnvm/pass.h =====
/*!
* Copyright (c) 2016 by Contributors
* \file pass.h
* \brief Pass that can be applied to a graph.
*/
#ifndef NNVM_PASS_H_
#define NNVM_PASS_H_
namespace nnvm {
/*!
* \brief A PassFunction is an "Operator on Graph".
 * It takes a source graph and returns a graph that may or may
* not be the same as the input one.
*
* A pass function can either change the graph structure (thus,
* generating a new Graph), or add new attributes to the graph.
*
* \param src The graph to be transformed.
* \return The generated graph.
*/
typedef std::function<Graph (Graph src)> PassFunction;
/*!
* \brief Apply a series of pass transformations on the input graph.
* \param src The graph to be transformed.
* \param passes A list of pass names to be applied.
* \return The transformed graph
*/
Graph ApplyPasses(Graph src,
const std::vector<std::string>& passes);
/*!
* \brief Apply one pass to the graph.
* \param src The graph to be transformed.
* \param pass The name of pass to be applied.
* \return The transformed graph.
*/
inline Graph ApplyPass(Graph src, const std::string& pass) {
return ApplyPasses(src, {pass});
}
/*!
 * \brief Registry entry for pass functions.
*/
struct PassFunctionReg
: public dmlc::FunctionRegEntryBase<PassFunctionReg,
PassFunction> {
/*!
* \brief Whether the pass will change graph structure
* If this is false, the pass will only change attributes.
*/
bool change_graph{false};
/*! \brief dependencies on operator attributes */
std::vector<std::string> op_attr_dependency;
/*! \brief dependencies on attributes in the graph */
std::vector<std::string> graph_attr_dependency;
/*! \brief generated targets of graph attributes */
std::vector<std::string> graph_attr_targets;
/*!
* \brief Set whether this pass will change graph structure.
* \param v If true, the pass will change graph structure.
* \return Reference to self.
*/
PassFunctionReg& set_change_graph(bool v) { // NOLINT(*)
change_graph = v;
return *this;
}
/*!
* \brief Declare that this pass will generate the given graph attribute name
* once it is applied on the graph.
* \param attr_name Name of the graph attribute.
* \return Reference to self.
*/
PassFunctionReg& provide_graph_attr(const std::string& attr_name) { // NOLINT(*)
graph_attr_targets.push_back(attr_name);
return *this;
}
/*!
* \brief Declare this pass requires the given operator attribute to be
* available before being applied on the graph.
* \param attr_name Name of the attribute.
* \return Reference to self.
*/
PassFunctionReg& depend_op_attr(const std::string& attr_name) { // NOLINT(*)
op_attr_dependency.push_back(attr_name);
return *this;
}
/*!
* \brief Declare this pass requires the given graph attribute to be
* available before being applied on the graph.
* \param attr_name Name of the attribute.
* \return Reference to self.
*/
PassFunctionReg& depend_graph_attr(const std::string& attr_name) { // NOLINT(*)
graph_attr_dependency.push_back(attr_name);
return *this;
}
};
/*!
* \def NNVM_REGISTER_PASS
 * \brief Macro to register pass functions.
*
* \code
* // example of registering a shape inference pass
* NNVM_REGISTER_PASS(InferShape)
* .describe("Shape Inference function, generate graph attributes")
* .provide_graph_attr("data_shape")
* .depend_graph_attr("indexed_graph")
* .depend_op_attr("infer_shape")
* .set_body([](const Graph& g) {
* // shape inference logic
* });
* \endcode
*/
#define NNVM_REGISTER_PASS(name) \
DMLC_REGISTRY_REGISTER(::nnvm::PassFunctionReg, PassFunctionReg, name)
} // namespace nnvm
#endif // NNVM_PASS_H_
//===== EXPANDED: ../nnvm/include/nnvm/pass.h =====
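/*!
 * A minimal invocation sketch (assumes `sym` is an existing Symbol and that
 * passes named "InferShape" and "InferType" have been registered elsewhere;
 * "InferShape" is the same example name used in the macro documentation above):
 * \code
 * nnvm::Graph g;
 * g.outputs = sym.outputs;
 * g = nnvm::ApplyPass(std::move(g), "InferShape");
 * // or chain several passes; they run in the given order
 * g = nnvm::ApplyPasses(std::move(g), {"InferShape", "InferType"});
 * \endcode
 */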
namespace dmlc {
// enable registry
DMLC_REGISTRY_ENABLE(nnvm::PassFunctionReg);
} // namespace dmlc
namespace nnvm {
const PassFunctionReg* FindPassDep(const std::string&attr_name) {
for (auto* r : dmlc::Registry<PassFunctionReg>::List()) {
for (auto& s : r->graph_attr_targets) {
if (s == attr_name) return r;
}
}
return nullptr;
}
Graph ApplyPasses(Graph g,
const std::vector<std::string>& pass) {
std::vector<const PassFunctionReg*> fpass;
for (auto& name : pass) {
auto* reg = dmlc::Registry<PassFunctionReg>::Find(name);
CHECK(reg != nullptr)
<< "Cannot find pass " << name << " in the registry";
fpass.push_back(reg);
}
for (auto r : fpass) {
for (auto& dep : r->graph_attr_dependency) {
if (g.attrs.count(dep) == 0) {
auto* pass_dep = FindPassDep(dep);
std::string msg;
if (pass_dep != nullptr) {
msg = " The attribute is provided by pass " + pass_dep->name;
}
LOG(FATAL) << "Graph attr dependency " << dep
<< " is required by pass " << r->name
<< " but is not available "
<< msg;
}
}
g = r->body(std::move(g));
}
return g;
}
} // namespace nnvm
//===== EXPANDED: ../nnvm/src/core/pass.cc =====
//===== EXPANDING: ../nnvm/src/pass/gradient.cc =====
/*!
* Copyright (c) 2016 by Contributors
* \file gradients.cc
* \brief Passes that takes gradient of the graph
 * This code was modified from the mxnet codebase by Min Lin
*/
namespace nnvm {
namespace pass {
namespace {
// default aggregate gradient function
// requires operators __zero__ and __ewise_sum__ to be present.
NodeEntry DefaultAggregateGradient(std::vector<NodeEntry>&& v) {
if (v.size() == 1) {
return std::move(v[0]);
} else if (v.size() == 0) {
NodePtr zero_node = Node::Create();
zero_node->attrs.op = Op::Get("__zero__");
return NodeEntry{zero_node, 0, 0};
} else {
NodePtr sum_node = Node::Create();
sum_node->attrs.op = Op::Get("__ewise_sum__");
sum_node->inputs = std::move(v);
return NodeEntry{sum_node, 0, 0};
}
}
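// Behavior sketch of the default aggregator above (illustrative only):
//   {}            -> a fresh node calling "__zero__"   (gradient is all zeros)
//   {g0}          -> g0 itself                          (no extra node created)
//   {g0, g1, ...} -> a node calling "__ewise_sum__" over g0, g1, ...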
// helper entry
struct GradEntry {
#ifdef _MSC_VER
NodeEntry sum = NodeEntry{nullptr, 0, 0};
#else
NodeEntry sum{nullptr, 0, 0};
#endif
std::vector<NodeEntry> grads;
bool need_attr_hint{true};
};
Graph Gradient(Graph src) {
using nnvm::FGradient;
using MirrorFun = std::function<int (const Node& node)>;
using AttrHintFun = std::function<NodeEntry (const NodeEntry& src, const NodeEntry &like)>;
CHECK_NE(src.attrs.count("grad_ys"), 0U)
<< "Gradient require grad_ys to be presented.";
CHECK_NE(src.attrs.count("grad_ys_out_grad"), 0U)
<< "Gradient require grad_ys_out_grad to be presented.";
CHECK_NE(src.attrs.count("grad_xs"), 0U)
<< "Gradient require grad_xs to be presented.";
const std::vector<NodeEntry>& ys =
src.GetAttr<std::vector<NodeEntry> >("grad_ys");
const std::vector<NodeEntry>& ys_out_grad =
src.GetAttr<std::vector<NodeEntry> >("grad_ys_out_grad");
const std::vector<NodeEntry>& xs =
src.GetAttr<std::vector<NodeEntry> >("grad_xs");
using AggFun = std::function<NodeEntry (std::vector<NodeEntry>&& inputs)>;
AggFun agg_fun = DefaultAggregateGradient;
if (src.attrs.count("grad_aggregate_fun") != 0) {
agg_fun = src.GetAttr<AggFun>("grad_aggregate_fun");
}
MirrorFun mirror_fun = nullptr;
if (src.attrs.count("grad_mirror_fun") != 0) {
mirror_fun = src.GetAttr<MirrorFun>("grad_mirror_fun");
}
AttrHintFun attr_hint_fun = nullptr;
if (src.attrs.count("attr_hint_fun") != 0) {
attr_hint_fun = src.GetAttr<AttrHintFun>("attr_hint_fun");
}
// topo sort
std::vector<NodePtr> topo_order;
std::unordered_map<Node*, std::vector<GradEntry> > output_grads;
DFSVisit(ys, [&](const NodePtr& node) {
if (output_grads.count(node.get()) == 0) {
output_grads[node.get()].resize(node->num_outputs());
}
topo_order.push_back(node);
});
CHECK_EQ(ys.size(), ys_out_grad.size());
for (size_t i = 0; i < ys.size(); ++i) {
NodeEntry ograd = ys_out_grad[i];
output_grads[ys[i].node.get()][ys[i].index].grads = { ograd };
}
// construct the mirror (memory-reduction) strategy if needed
std::unordered_map<Node*, NodePtr> mirror_map;
if (mirror_fun != nullptr) {
for (const NodePtr& n : topo_order) {
if (mirror_fun(*n)) {
NodePtr new_node = Node::Create();
*new_node = *n;
new_node->attrs.name += "_mirror";
for (auto& e : new_node->inputs) {
e.node = mirror_map.at(e.node.get());
}
for (auto& n : new_node->control_deps) {
n = mirror_map.at(n.get());
}
mirror_map[n.get()] = std::move(new_node);
} else {
mirror_map[n.get()] = n;
}
}
}
// traverse backward
static auto& grad_fun_map = Op::GetAttr<FGradient>("FGradient");
static auto& finfer_shape = Op::GetAttr<FInferShape>("FInferShape");
std::vector<NodeEntry> out_agg_grads;
for (auto rit = topo_order.rbegin(); rit != topo_order.rend(); ++rit) {
const NodePtr& ptr = *rit;
if (ptr->is_variable()) continue;
out_agg_grads.clear();
auto& out_grad_vec = output_grads.at(ptr.get());
for (uint32_t i = 0; i < out_grad_vec.size(); ++i) {
GradEntry& e = out_grad_vec[i];
e.sum = agg_fun(std::move(e.grads));
if (e.need_attr_hint && attr_hint_fun != nullptr) {
e.sum = attr_hint_fun(e.sum, NodeEntry{ptr, 0, i});
}
out_agg_grads.push_back(e.sum);
}
if ((*rit)->inputs.size() != 0) {
NodePtr fwd_node = (mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get()));
std::vector<NodeEntry> input_grads = grad_fun_map[ptr->op()](
fwd_node, out_agg_grads);
CHECK_EQ((*rit)->inputs.size(), input_grads.size())
<< "Gradient function not returning enough gradient";
auto git = input_grads.begin();
for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) {
auto& ge = output_grads[it->node.get()][it->index];
// if any of the backward ops can do shape inference, the hint is not necessary.
if (finfer_shape.count(git->node->op())) {
ge.need_attr_hint = false;
}
ge.grads.emplace_back(std::move(*git));
}
}
}
// take out the xs' grads
Graph ret;
ret.outputs.reserve(xs.size());
for (const NodeEntry& e : xs) {
GradEntry& entry = output_grads[e.node.get()][e.index];
// aggregate into the sum if it has not been done yet
if (entry.sum.node.get() == nullptr) {
entry.sum = agg_fun(std::move(entry.grads));
if (entry.need_attr_hint && attr_hint_fun != nullptr) {
entry.sum = attr_hint_fun(entry.sum, e);
}
}
ret.outputs.emplace_back(std::move(entry.sum));
}
return ret;
}
// register pass
NNVM_REGISTER_PASS(Gradient)
.describe("Return a gradient graph of src.attrs[\"ys\"] wrt src.attrs[\"xs\"]")
.set_body(Gradient)
.set_change_graph(true)
.depend_graph_attr("grad_ys")
.depend_graph_attr("grad_xs")
.depend_graph_attr("grad_ys_out_grad");
} // namespace
} // namespace pass
} // namespace nnvm
//===== EXPANDED: ../nnvm/src/pass/gradient.cc =====
//===== EXPANDING: ../nnvm/src/pass/order_mutation.cc =====
/*!
* Copyright (c) 2016 by Contributors
* \file order_mutation.cc
* \brief Add control flow dependencies between nodes
* To correctly order mutation and read to resolve
* write after read problem and read after write problems.
*/
namespace nnvm {
namespace pass {
namespace {
template<typename T>
inline T get_with_default(const std::unordered_map<Node*, T> &map,
Node* key,
const T& def) {
auto it = map.find(key);
if (it != map.end()) return it->second;
return def;
}
inline bool IsMutate(const std::vector<uint32_t>& mutate_inputs, uint32_t i) {
return std::binary_search(mutate_inputs.begin(), mutate_inputs.end(), i);
}
Graph OrderMutation(const Graph& src) {
std::unordered_map<Node*, std::vector<NodeEntry> > version_hist;
DFSVisit(src.outputs, [&version_hist](const NodePtr& n) {
for (const NodeEntry& e : n->inputs) {
if (e.node->is_variable()) {
if (e.version != 0 && version_hist.count(e.node.get()) == 0) {
version_hist[e.node.get()] = std::vector<NodeEntry>{};
}
}
}
});
// no mutation happens, everything is fine.
if (version_hist.size() == 0) return src;
// start preparing for remapping the nodes.
std::unordered_map<Node*, NodePtr> old_new;
auto prepare = [&version_hist, &old_new] (const NodePtr& n) {
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
std::vector<uint32_t> mutate_inputs;
if (!n->is_variable() && fmutate_inputs.count(n->op())) {
mutate_inputs = fmutate_inputs[n->op()](n->attrs);
}
std::sort(mutate_inputs.begin(), mutate_inputs.end());
bool need_repl = false;
for (size_t i = 0; i < n->inputs.size(); ++i) {
const NodeEntry& e = n->inputs[i];
if (e.node->is_variable()) {
if (e.version != 0) need_repl = true;
auto it = version_hist.find(e.node.get());
if (it != version_hist.end()) {
std::vector<NodeEntry>& vec = it->second;
vec.emplace_back(NodeEntry{n, IsMutate(mutate_inputs, i), e.version});
}
} else {
if (old_new.count(e.node.get()) != 0) need_repl = true;
}
}
for (const NodePtr& p : n->control_deps) {
if (old_new.count(p.get()) != 0) need_repl = true;
}
if (need_repl) {
NodePtr np = Node::Create();
np->attrs = n->attrs;
old_new[n.get()] = std::move(np);
}
};
DFSVisit(src.outputs, prepare);
// comparator of history entry
auto comparator = [](const NodeEntry& a, const NodeEntry &b) {
if (a.version < b.version) return true;
if (a.version > b.version) return false;
return a.index > b.index;
};
for (auto &kv : version_hist) {
std::sort(kv.second.begin(), kv.second.end(), comparator);
}
// copy the nodes, as well as add control deps
for (auto &kv : old_new) {
// copy the nodes
for (const NodeEntry& e : kv.first->inputs) {
auto it = old_new.find(e.node.get());
if (it != old_new.end()) {
kv.second->inputs.emplace_back(NodeEntry{it->second, e.index, e.version});
} else {
kv.second->inputs.push_back(e);
}
}
for (const NodePtr& p : kv.first->control_deps) {
kv.second->control_deps.emplace_back(
get_with_default(old_new, p.get(), p));
}
// add control deps
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
std::vector<uint32_t> mutate_inputs;
if (fmutate_inputs.count(kv.first->op())) {
mutate_inputs = fmutate_inputs[kv.first->op()](kv.first->attrs);
}
std::sort(mutate_inputs.begin(), mutate_inputs.end());
for (size_t i = 0; i < kv.first->inputs.size(); ++i) {
const NodeEntry& e = kv.first->inputs[i];
if (e.node->is_variable() && version_hist.count(e.node.get()) != 0) {
std::vector<NodeEntry>& vec = version_hist.at(e.node.get());
auto it = std::lower_bound(vec.begin(), vec.end(),
NodeEntry{nullptr, 1, e.version},
comparator);
if (IsMutate(mutate_inputs, i)) {
int read_dep = 0;
while (it != vec.begin()) {
--it;
if (it->index != 0) break;
++read_dep;
// depend on previous read
kv.second->control_deps.push_back(
get_with_default(old_new, it->node.get(), it->node));
}
if (read_dep == 0 && it->index != 0) {
// depend on last write
kv.second->control_deps.push_back(
get_with_default(old_new, it->node.get(), it->node));
}
} else {
// depend on last write
if (it->index != 0) {
kv.second->control_deps.push_back(
get_with_default(old_new, it->node.get(), it->node));
}
}
}
}
}
Graph ret;
for (const NodeEntry &e : src.outputs) {
ret.outputs.emplace_back(NodeEntry{
get_with_default(old_new, e.node.get(), e.node), e.index, e.version});
}
return ret;
}
NNVM_REGISTER_PASS(OrderMutation)
.describe("Return a new graph that adds control dependencies, "\
"to order the mutation and reads if mutation exists.")
.set_body(OrderMutation)
.set_change_graph(true);
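// Effect sketch (illustrative, not actual pass output): if a variable `w` is
// read by node A (version 0) and then mutated by node B (producing version 1),
// the copy of B gains a control dependency on A, and any later reader of
// version 1 gains a control dependency on B, so reads and writes stay ordered.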
} // namespace
} // namespace pass
} // namespace nnvm
//===== EXPANDED: ../nnvm/src/pass/order_mutation.cc =====
//===== EXPANDING: ../nnvm/src/pass/plan_memory.cc =====
/*!
* Copyright (c) 2016 by Contributors
* \file plan_memory.cc
* \brief Assign memory tag to each of the data entries.
*/
//===== EXPANDING: ../nnvm/include/nnvm/graph_attr_types.h =====
/*!
* Copyright (c) 2016 by Contributors
* \file graph_attr_types.h
* \brief Data structures that can appear in graph attributes.
*/
#ifndef NNVM_GRAPH_ATTR_TYPES_H_
#define NNVM_GRAPH_ATTR_TYPES_H_
namespace nnvm {
/*!
* \brief The result holder of JSON serializer
*
* \note Stored under ret.attrs["json"], provided by Pass "SaveJSON"
* \code
* Graph ret = ApplyPass(src_graph, "SaveJSON");
 * const JSONString& json = ret.GetAttr<JSONString>("json");
* \endcode
*/
using JSONString = std::string;
/*!
* \brief The result holder of shape of each NodeEntry in the graph.
* \note Stored under graph.attrs["shape"], provided by Pass "InferShape"
*
* \code
* Graph g = ApplyPass(src_graph, "InferShape");
* const ShapeVector& shapes = g.GetAttr<ShapeVector>("shape");
* // get shape by entry id
* TShape entry_shape = shapes[g.indexed_graph().entry_id(my_entry)];
* \endcode
*
* \sa FInferShape
*/
using ShapeVector = std::vector<TShape>;
/*!
* \brief The result holder of type of each NodeEntry in the graph.
* \note Stored under graph.attrs["dtype"], provided by Pass "InferType"
*
* \code
* Graph g = ApplyPass(src_graph, "InferType");
* const DTypeVector& types = g.GetAttr<DTypeVector>("dtype");
 * // get dtype by entry id
 * int entry_type = types[g.indexed_graph().entry_id(my_entry)];
* \endcode
*
* \sa FInferType
*/
using DTypeVector = std::vector<int>;
/*!
* \brief The result holder of device of each operator in the graph.
* \note Stored under graph.attrs["device"], provided by Pass "PlaceDevice"
*
* \code
* Graph g = ApplyPass(src_graph, "PlaceDevice");
 * const DeviceVector& device = g.GetAttr<DeviceVector>("device");
* // get device by node_id
* int device_type = device[g.indexed_graph().node_id(my_node)];
* \endcode
*/
using DeviceVector = std::vector<int>;
/*!
 * \brief The device assignment map from device group to device id.
*
* \note Stored under graph.attrs["device_assign_map"], needed by Pass "PlaceDevice"
* -1 means unknown device
*/
using DeviceAssignMap = std::unordered_map<std::string, int>;
/*!
* \brief The result holder of storage id of each NodeEntry in the graph.
*
* \note Stored under graph.attrs["storage"], provided by Pass "PlanMemory"
* Storage id is a continuous integer.
* If the storage id is -1 then the storage is not assigned.
*
* \code
* Graph g = ApplyPass(src_graph, "PlanMemory");
 * const StorageVector& storage = g.GetAttr<StorageVector>("storage");
* // get storage id by entry
* int storage_id = storage[g.indexed_graph().entry_id(my_entry)];
* \endcode
*/
using StorageVector = std::vector<int>;
} // namespace nnvm
#endif // NNVM_GRAPH_ATTR_TYPES_H_
//===== EXPANDED: ../nnvm/include/nnvm/graph_attr_types.h =====
//===== EXPANDING: ../nnvm/src/pass/graph_algorithm.h =====
/*!
* Copyright (c) 2016 by Contributors
* \file graph_algorithm.h
* \brief This header contains graph algorithms on StaticGraph.
 * It is used to compute information such as whether two
* operations can run in parallel, and helps allocation.
*/
#ifndef NNVM_PASS_GRAPH_ALGORITHM_H_
#define NNVM_PASS_GRAPH_ALGORITHM_H_
namespace nnvm {
namespace pass {
/*!
* \brief Find best path in the DAG, with reward defined
* by sum of reward of each node along the path.
* \param graph the original static graph.
* \param node_reward the reward of each node.
* \param path the output path of nodes.
* \return the total reward of best path.
*/
inline uint32_t FindBestPath(
const IndexedGraph& graph,
const std::vector<uint32_t>& node_reward,
std::vector<uint32_t>* path) {
const uint32_t num_nodes = static_cast<uint32_t>(graph.num_nodes());
CHECK_EQ(num_nodes, node_reward.size());
std::vector<uint32_t> best_reward(node_reward.size(), 0);
std::vector<uint32_t> next_node(node_reward.size(), num_nodes);
uint32_t best_solution = 0, best_start_node = 0;
// traverse in reverse topo order
for (uint32_t i = static_cast<uint32_t>(graph.num_nodes()); i != 0; --i) {
const uint32_t nid = i - 1;
best_reward[nid] += node_reward[nid];
if (best_reward[nid] > best_solution) {
best_solution = best_reward[nid];
best_start_node = nid;
}
for (const auto& e : graph[nid].inputs) {
const uint32_t prev = e.node_id;
if (best_reward[nid] > best_reward[prev]) {
best_reward[prev] = best_reward[nid];
next_node[prev] = nid;
}
}
}
path->clear();
uint32_t reward = 0;
for (uint32_t nid = best_start_node; nid < num_nodes; nid = next_node[nid]) {
path->push_back(nid); reward += node_reward[nid];
}
CHECK_EQ(reward, best_solution);
return best_solution;
}
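// Worked example (hand-checked, illustrative): for a chain 0 -> 1 -> 2 with
// node_reward = {1, 3, 2}, rewards accumulate from the outputs back to the
// inputs, so the returned path is {0, 1, 2} and the best reward is 1+3+2 = 6.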
/*!
* \brief Color the nodes in the graph into index.
* The coloring algorithm tries to assign node group
* such that node in the same group cannot run in parallel.
*
* \param graph the original indexed graph.
* \param node_importance The importance of the node
* \param max_ncolor maximum number of colors allowed.
* \param color the color index of each of the node.
* \return the total number of colors.
*/
inline uint32_t ColorNodeGroup(
const IndexedGraph &graph,
std::vector<uint32_t> node_importance,
uint32_t max_ncolor,
std::vector<uint32_t> *color) {
CHECK_NE(max_ncolor, 0U);
CHECK_EQ(graph.num_nodes(), node_importance.size());
color->clear();
color->resize(graph.num_nodes(), max_ncolor);
uint32_t cindex;
// greedy algorithm: each iteration finds the path with the best reward
// and assigns it a new color.
// All the nodes on one path cannot run in parallel.
for (cindex = 0; cindex < max_ncolor - 1; ++cindex) {
std::vector<uint32_t> path;
uint32_t reward = FindBestPath(graph, node_importance, &path);
if (reward == 0) break;
for (uint32_t nid : path) {
if (node_importance[nid] != 0) {
CHECK_EQ(color->at(nid), max_ncolor);
color->at(nid) = cindex;
// make the importance 0 after color is decided.
node_importance[nid] = 0;
}
}
}
// assign the last color index to the remaining nodes
for (uint32_t i = 0; i < graph.num_nodes(); ++i) {
if (color->at(i) == max_ncolor) {
color->at(i) = cindex;
}
}
return cindex + 1;
}
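// Usage sketch (the IndexedGraph `idx` and the importance values are assumptions):
//   std::vector<uint32_t> importance(idx.num_nodes(), 1);  // every node equally important
//   std::vector<uint32_t> color;
//   uint32_t ncolor = ColorNodeGroup(idx, importance, 4, &color);
//   // nodes sharing a color lie on one path and cannot run in parallel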
} // namespace pass
} // namespace nnvm
#endif // NNVM_PASS_GRAPH_ALGORITHM_H_
//===== EXPANDED: ../nnvm/src/pass/graph_algorithm.h =====
namespace nnvm {
namespace pass {
namespace {
// simple graph based allocator.
class GraphAllocator {
public:
// storage id equals integer.
using StorageID = int;
// bad storage id
static const StorageID kBadStorageID = -1;
// external storage id
static const StorageID kExternalStorageID = -2;
// request a free storage
StorageID Request(int dev_id, int dtype, TShape shape, uint32_t node_id) {
if (shape.ndim() == 0) return kBadStorageID;
// search memory block in [size / match_range_, size * match_range_)
// TODO(tqchen) add size of the dtype, assume 4 bytes for now
size_t size = shape.Size() * 4;
if (match_range_ == 0) return this->Alloc(dev_id, size);
auto begin = free_.lower_bound(size / match_range_);
auto mid = free_.lower_bound(size);
auto end = free_.upper_bound(size * match_range_);
// search for memory blocks larger than requested
for (auto it = mid; it != end; ++it) {
StorageEntry *e = it->second;
if (e->device_id != dev_id) continue;
if (node_color_.size() != 0 &&
node_color_[e->released_by_node] != node_color_[node_id]) continue;
// use exact matching strategy
e->max_bytes = std::max(size, e->max_bytes);
// found a match, erase it from the free list and return
free_.erase(it);
return e->id;
}
// then search for memory blocks smaller than requested space
for (auto it = mid; it != begin;) {
--it;
StorageEntry *e = it->second;
if (e->device_id != dev_id) continue;
if (node_color_.size() != 0 &&
node_color_[e->released_by_node] != node_color_[node_id]) continue;
// use exact matching strategy
e->max_bytes = std::max(size, e->max_bytes);
// found a match, erase it from the free list and return
free_.erase(it);
return e->id;
}
// cannot find anything, allocate a new one.
return this->Alloc(dev_id, size);
}
// release a memory space.
void Release(StorageID id, uint32_t node_id) {
CHECK_NE(id, kBadStorageID);
if (id == kExternalStorageID) return;
StorageEntry *e = data_[id].get();
e->released_by_node = node_id;
free_.insert({e->max_bytes, e});
}
// total number of bytes allocated
size_t TotalAllocBytes() const {
size_t total = 0;
for (auto &p : data_) {
total += p->max_bytes;
}
return total;
}
// constructor
explicit GraphAllocator(const IndexedGraph* idx) : idx_(idx) {
this->Init(dmlc::GetEnv("NNVM_EXEC_MATCH_RANGE", 16),
dmlc::GetEnv("NNVM_EXEC_NUM_TEMP", 1));
}
private:
// initialize the graph allocator
void Init(size_t match_range, uint32_t num_match_color) {
match_range_ = match_range;
num_match_color_ = num_match_color;
if (num_match_color_ > 1) {
std::vector<uint32_t> importance(idx_->num_nodes(), 0);
for (uint32_t nid = 0; nid < idx_->num_nodes(); ++nid) {
if ((*idx_)[nid].source->is_variable()) continue;
importance[nid] = 1;
}
num_match_color_ = pass::ColorNodeGroup(
*idx_, importance, num_match_color_, &node_color_);
}
}
StorageID Alloc(int dev_id, size_t size) {
StorageID id = static_cast<StorageID>(data_.size());
std::unique_ptr<StorageEntry> ptr(new StorageEntry());
ptr->id = id;
ptr->device_id = dev_id;
ptr->max_bytes = size;
data_.emplace_back(std::move(ptr));
return id;
}
// internal storage entry
struct StorageEntry {
// the id of the entry.
StorageID id;
// the device id of the storage.
int device_id;
// maximum size of storage requested.
size_t max_bytes{0};
// node index that released it last time
uint32_t released_by_node{0};
};
// scale used for rough match
size_t match_range_;
// number of colors used by the color-based matching algorithm
uint32_t num_match_color_{1};
// the size of each dtype
std::vector<size_t> dtype_size_dict_;
// free list of storage entry
std::multimap<size_t, StorageEntry*> free_;
// all the storage resources available
std::vector<std::unique_ptr<StorageEntry> > data_;
// color of nodes in the graph, used for auxiliary policy making.
std::vector<uint32_t> node_color_;
// internal indexed graph
const IndexedGraph* idx_;
};
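// Usage sketch of the allocator above (illustrative; PlanMemory below is the
// real caller, and dev_id/dtype/shape/nid stand in for actual values):
//   GraphAllocator allocator(&idx);
//   auto sid = allocator.Request(dev_id, dtype, shape, nid);  // grab or reuse a block
//   /* ... entry consumed by its readers ... */
//   allocator.Release(sid, nid);                              // back to the free list
//   size_t total = allocator.TotalAllocBytes();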
// function to plan memory
Graph PlanMemory(Graph ret) {
// setup ref counter
const IndexedGraph& idx = ret.indexed_graph();
static auto& fignore_inputs = Op::GetAttr<FIgnoreInputs>("FIgnoreInputs");
// reference counter of each node
std::vector<uint32_t> ref_count(idx.num_node_entries(), 0);
// step 1: initialize reference count
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
if (inode.source->is_variable()) continue;
for (const auto& e : inode.inputs) {
++ref_count[idx.entry_id(e)];
}
// no dataflow dependency is needed for inputs that are ignored;
// revoke their dependency counters.
if (fignore_inputs.count(inode.source->op()) != 0) {
auto ignore_inputs = fignore_inputs[inode.source->op()](inode.source->attrs);
for (uint32_t i : ignore_inputs) {
--ref_count[idx.entry_id(inode.inputs[i])];
}
}
}
for (const auto& e : idx.outputs()) {
++ref_count[idx.entry_id(e)];
}
// step 2: allocate memory.
StorageVector storage;
if (ret.attrs.count("storage") != 0) {
storage = ret.MoveCopyAttr<StorageVector>("storage");
} else {
storage.resize(idx.num_node_entries(), -1);
}
std::vector<int> storage_inplace_index(idx.num_node_entries(), -1);
const ShapeVector& shape_vec = ret.GetAttr<ShapeVector>("shape");
const DTypeVector& dtype_vec = ret.GetAttr<DTypeVector>("dtype");
const DeviceVector* device_vec = nullptr;
static auto& finplace_option = Op::GetAttr<FInplaceOption>("FInplaceOption");
if (ret.attrs.count("device") != 0) {
device_vec = &(ret.GetAttr<DeviceVector>("device"));
}
// the allocator.
GraphAllocator allocator(&idx);
// number of entries that are not statically allocated.
size_t num_not_allocated = 0;
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
if (inode.source->is_variable()) continue;
// check inplace option
if (finplace_option.count(inode.source->op()) != 0) {
auto inplace_pairs = finplace_option[inode.source->op()](inode.source->attrs);
for (auto& kv : inplace_pairs) {
uint32_t eid_out = idx.entry_id(nid, kv.second);
uint32_t eid_in = idx.entry_id(inode.inputs[kv.first]);
if (ref_count[eid_in] == 1 &&
ref_count[eid_out] != 0 &&
storage[eid_out] == GraphAllocator::kBadStorageID &&
storage[eid_in] != GraphAllocator::kBadStorageID &&
shape_vec[eid_out].Size() == shape_vec[eid_in].Size() &&
dtype_vec[eid_out] == dtype_vec[eid_in]) {
// inplace optimization
storage[eid_out] = storage[eid_in];
ref_count[eid_in] = 0;
storage_inplace_index[eid_out] = kv.first;
}
}
}
// normal allocation
const int dev_id = (device_vec != nullptr) ? device_vec->at(nid) : 0;
// allocate output
for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
uint32_t eid = idx.entry_id(nid, index);
if (storage[eid] == GraphAllocator::kBadStorageID) {
storage[eid] = allocator.Request(dev_id, dtype_vec[eid], shape_vec[eid], nid);
}
}
// check if certain inputs are ignored.
std::vector<uint32_t> ignore_inputs;
if (fignore_inputs.count(inode.source->op()) != 0) {
ignore_inputs = fignore_inputs[inode.source->op()](inode.source->attrs);
std::sort(ignore_inputs.begin(), ignore_inputs.end());
}
// then free inputs
for (size_t i = 0; i < inode.inputs.size(); ++i) {
// ref counter of ignored input is already decreased.
if (std::binary_search(ignore_inputs.begin(), ignore_inputs.end(), i)) continue;
const auto& e = inode.inputs[i];
uint32_t eid = idx.entry_id(e);
// ref_count == 0 means it is taken by an inplace op
if (ref_count[eid] == 0) continue;
// if we decrease it to zero, we are ready to release it
--ref_count[eid];
if (ref_count[eid] == 0 && storage[eid] != GraphAllocator::kBadStorageID) {
allocator.Release(storage[eid], nid);
}
}
// check if there are outputs that can be freed immediately
// these outputs are not referenced by any operator.
for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
uint32_t eid = idx.entry_id(nid, index);
if (ref_count[eid] == 0 && storage[eid] != GraphAllocator::kBadStorageID) {
allocator.Release(storage[eid], nid);
// use -2 to indicate that this output is never referenced.
storage_inplace_index[eid] = -2;
}
if (storage[eid] == GraphAllocator::kBadStorageID) {
++num_not_allocated;
}
}
}
ret.attrs["storage_id"] = std::make_shared<any>(std::move(storage));
ret.attrs["storage_inplace_index"] = std::make_shared<any>(std::move(storage_inplace_index));
ret.attrs["storage_allocated_bytes"] = std::make_shared<any>(allocator.TotalAllocBytes());
ret.attrs["storage_num_not_allocated"] = std::make_shared<any>(num_not_allocated);
return ret;
}
NNVM_REGISTER_PASS(PlanMemory)
.describe("Plan the memory allocation of each node entries.")
.set_body(PlanMemory)
.set_change_graph(false)
.depend_graph_attr("dtype")
.depend_graph_attr("shape")
.provide_graph_attr("storage_id")
.provide_graph_attr("storage_inplace_index");
} // namespace
} // namespace pass
} // namespace nnvm
//===== EXPANDED: ../nnvm/src/pass/plan_memory.cc =====
//===== EXPANDING: ../nnvm/src/pass/infer_shape_type.cc =====
/*!
* Copyright (c) 2016 by Contributors
* \file infer_shape.cc
 * \brief Infer the shapes given existing information.
*/
namespace nnvm {
namespace pass {
namespace {
template<typename AttrType, typename IsNone, typename FDefault>
Graph InferAttr(Graph &&ret,
const AttrType empty_val,
const char* infer_name,
const char* input_name,
const char* attr_key_name,
const char* attr_name,
const char* unknown_name,
IsNone fis_none,
FDefault fdefault) {
using AttrVector = std::vector<AttrType>;
const IndexedGraph& idx = ret.indexed_graph();
static auto& finfer_shape =
Op::GetAttr<FInferNodeEntryAttr<AttrType> >(infer_name);
static auto& is_backward =
Op::GetAttr<TIsBackward>("TIsBackward");
// gradient function, used to get node correspondence.
static auto& fgrad =
Op::GetAttr<FGradient>("FGradient");
// the resulting attribute (e.g. shape) vector
AttrVector rshape;
if (ret.attrs.count(attr_name) != 0) {
rshape = ret.MoveCopyAttr<AttrVector>(attr_name);
} else {
rshape.resize(idx.num_node_entries(), empty_val);
}
if (ret.attrs.count(input_name) != 0) {
const AttrVector& shape_args = ret.GetAttr<AttrVector>(input_name);
CHECK_LE(shape_args.size(), idx.input_nodes().size())
<< "More provided shapes than number of arguments.";
for (size_t i = 0; i < shape_args.size(); ++i) {
rshape[idx.entry_id(idx.input_nodes()[i], 0)] = shape_args[i];
}
// erase the provided arguments
ret.attrs.erase(input_name);
}
std::string shape_attr_key;
if (ret.attrs.count(attr_key_name) != 0) {
shape_attr_key = ret.GetAttr<std::string>(attr_key_name);
// erase the provided arguments
ret.attrs.erase(attr_key_name);
}
// Temp space for shape inference.
std::vector<AttrType> ishape, oshape;
// inference step function for nid
auto infer_step = [&](uint32_t nid, bool last_iter) {
const auto& inode = idx[nid];
const uint32_t num_inputs = inode.inputs.size();
const uint32_t num_outputs = inode.source->num_outputs();
if (inode.source->is_variable()) {
// Variable node. No operator. Only one output entry.
CHECK(inode.source->op() == nullptr);
CHECK_EQ(num_outputs, 1U);
const uint32_t out_ent_id = idx.entry_id(nid, 0);
if (shape_attr_key.length() != 0 && fis_none(rshape[out_ent_id])) {
auto it = inode.source->attrs.dict.find(shape_attr_key);
if (it != inode.source->attrs.dict.end()) {
std::istringstream is(it->second);
CHECK(is >> rshape[out_ent_id]) << "Invalid attribute";
}
}
} else if (is_backward.get(inode.source->op(), false)) {
CHECK_GE(inode.control_deps.size(), 1U)
<< "BackwardOp need to have control_deps to its forward op";
const IndexedGraph::Node& fnode = idx[inode.control_deps[0]];
NodePtr fwd_ptr = inode.source->control_deps[0];
// use gradient function to find out the correspondence.
std::vector<NodeEntry> ograd(fwd_ptr->num_outputs());
for (size_t i = 0; i < ograd.size(); ++i) {
ograd[i].index = static_cast<uint32_t>(i);
}
// input gradient list
auto igrad = fgrad[fwd_ptr->op()](fwd_ptr, ograd);
const Op* backward_op = inode.source->op();
const Node* igrad_node = nullptr;
// Input gradient assignment
for (size_t i = 0; i < igrad.size(); ++i) {
if (igrad[i].node->op() == backward_op) {
uint32_t eid = idx.entry_id(nid, igrad[i].index);
if (fis_none(rshape[eid])) {
rshape[eid] = rshape[idx.entry_id(fnode.inputs[i])];
} else {
CHECK_EQ(rshape[eid], rshape[idx.entry_id(fnode.inputs[i])])
<< "Backward shape inconsistent with the forward shape";
}
if (igrad_node == nullptr) {
igrad_node = igrad[i].node.get();
} else {
CHECK(igrad_node == igrad[i].node.get());
}
}
}
// out grad entries
for (size_t i = 0; i < igrad_node->inputs.size(); ++i) {
const NodeEntry& e = igrad_node->inputs[i];
if (e.node == nullptr) {
uint32_t eid = idx.entry_id(inode.inputs[i]);
if (fis_none(rshape[eid])) {
rshape[eid] = rshape[idx.entry_id(inode.control_deps[0], e.index)];
}
}
}
} else {
bool forward_known = true;
// Forward operator inference.
ishape.resize(num_inputs, empty_val);
for (uint32_t i = 0; i < ishape.size(); ++i) {
ishape[i] = rshape[idx.entry_id(inode.inputs[i])];
if (fis_none(ishape[i])) forward_known = false;
}
oshape.resize(num_outputs, empty_val);
for (uint32_t i = 0; i < oshape.size(); ++i) {
oshape[i] = rshape[idx.entry_id(nid, i)];
if (fis_none(oshape[i])) forward_known = false;
}
auto finfer = finfer_shape.get(inode.source->op(), fdefault);
if (!forward_known) {
if (finfer != nullptr) {
// Call inference function of the operator.
try {
forward_known = finfer(inode.source->attrs, &ishape, &oshape);
} catch (const std::exception& e) {
throw dmlc::Error("Error in operator " + inode.source->attrs.name + ": " + e.what());
}
} else {
CHECK(!last_iter)
<< "Attribute " << infer_name
<< " is not registed by op " << inode.source->op()->name
<< " we are not able to complete the inference because of this";
}
}
// Save to the result map.
for (uint32_t i = 0; i < num_inputs; ++i) {
rshape[idx.entry_id(inode.inputs[i])] = ishape[i];
}
for (uint32_t i = 0; i < num_outputs; ++i) {
rshape[idx.entry_id(nid, i)] = oshape[i];
}
}
};
size_t last_num_unknown;
size_t num_unknown = rshape.size();
int i = 0;
do {
if (i % 2 == 0) {
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
infer_step(nid, false);
}
} else {
// backward inference
for (uint32_t i = idx.num_nodes(); i != 0; --i) {
infer_step(i - 1, false);
}
}
last_num_unknown = num_unknown;
num_unknown = 0;
for (size_t j = 0; j < idx.num_node_entries(); ++j) {
if (fis_none(rshape[j])) {
++num_unknown;
}
}
++i;
} while (num_unknown > 0 && last_num_unknown > num_unknown);
// set the shapes
ret.attrs[attr_name] = std::make_shared<any>(std::move(rshape));
// number of entries whose attribute is still unknown.
ret.attrs[unknown_name] = std::make_shared<any>(num_unknown);
return ret;
}
NNVM_REGISTER_PASS(InferShape)
.describe("Infer the shape of each node entries.")
.set_body([](Graph ret) {
return InferAttr<TShape>(
std::move(ret), TShape(),
"FInferShape", "shape_inputs", "shape_attr_key",
"shape", "shape_num_unknown_nodes",
[](const TShape& s) { return s.ndim() == 0 || s.Size() == 0; },
nullptr);
})
.set_change_graph(false)
.provide_graph_attr("shape");
// inference function for same type
inline bool SameType(const NodeAttrs& attrs,
std::vector<int> *iattr,
std::vector<int> *oattr) {
int def_v = -1;
for (int v : *oattr) {
if (v != -1) {
def_v = v; break;
}
}
if (def_v == -1) {
for (int v : *iattr) {
if (v != -1) {
def_v = v; break;
}
}
}
if (def_v == -1) return false;
for (int& v : *oattr) {
v = def_v;
}
for (int& v : *iattr) {
v = def_v;
}
return true;
}
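// Worked example for SameType (illustrative): with *iattr = {-1, 0} and
// *oattr = {-1}, no output type is known, so the first known input type (0)
// becomes the default; both vectors end up as all 0 and the function returns true.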
NNVM_REGISTER_PASS(InferType)
.describe("Infer the dtype of each node entries.")
.set_body([](Graph ret) {
return InferAttr<int>(
std::move(ret), -1,
"FInferType", "dtype_inputs", "dtype_attr_key",
"dtype", "dtype_num_unknown_nodes",
[](const int t) { return t == -1; },
SameType);
})
.set_change_graph(false)
.provide_graph_attr("dtype");
DMLC_JSON_ENABLE_ANY(ShapeVector, list_shape);
DMLC_JSON_ENABLE_ANY(DTypeVector, list_int);
DMLC_JSON_ENABLE_ANY(size_t, size_t);
} // namespace
} // namespace pass
} // namespace nnvm
//===== EXPANDED: ../nnvm/src/pass/infer_shape_type.cc =====
//===== EXPANDING: ../nnvm/src/pass/place_device.cc =====
/*!
* Copyright (c) 2016 by Contributors
* \file place_device.cc
 * \brief Infer the device of each operator given known information.
 * Insert a copy node automatically when data crosses devices.
*/
namespace nnvm {
namespace pass {
namespace {
// simple logic to place devices according to the device_group hint
// insert a copy node when a cross-device reference occurs
Graph PlaceDevice(Graph src) {
CHECK(src.attrs.count("device_group_attr_key"))
<< "Need graph attribute \"device_group_attr_key\" in PlaceDevice";
CHECK(src.attrs.count("device_assign_map"))
<< "Need graph attribute \"device_assign_map\" in PlaceDevice";
CHECK(src.attrs.count("device_copy_op"))
<< "Need graph attribute \"device_copy_op\" in PlaceDevice";
std::string device_group_attr_key = src.GetAttr<std::string>("device_group_attr_key");
const Op* copy_op = Op::Get(src.GetAttr<std::string>("device_copy_op"));
auto& device_assign_map = src.GetAttr<DeviceAssignMap>("device_assign_map");
const IndexedGraph& idx = src.indexed_graph();
static auto& is_backward =
Op::GetAttr<TIsBackward>("TIsBackward");
DeviceVector device;
// copy on write semantics
if (src.attrs.count("device") != 0) {
device = src.MoveCopyAttr<DeviceVector>("device");
CHECK_EQ(device.size(), idx.num_nodes());
} else {
device.resize(idx.num_nodes(), -1);
}
// forward pass
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
auto it = inode.source->attrs.dict.find(device_group_attr_key);
if (it != inode.source->attrs.dict.end()) {
const std::string& device_group = it->second;
auto dit = device_assign_map.find(device_group);
CHECK(dit != device_assign_map.end())
<< "The device assignment not found for group " << device_group;
device[nid] = dit->second;
} else {
if (!inode.source->is_variable() &&
is_backward.get(inode.source->op(), false)) {
if (device[inode.control_deps[0]] != -1) {
device[nid] = device[inode.control_deps[0]];
}
} else {
for (const IndexedGraph::NodeEntry& e : inode.inputs) {
if (device[e.node_id] != -1) {
device[nid] = device[e.node_id]; break;
}
}
}
}
}
// backward pass
for (uint32_t i = idx.num_nodes(); i != 0; --i) {
uint32_t nid = i - 1;
const auto& inode = idx[nid];
if (device[nid] == -1) continue;
for (const IndexedGraph::NodeEntry& e : inode.inputs) {
if (device[e.node_id] == -1) device[e.node_id] = device[nid];
}
}
int num_dev = 1, other_dev_id = -1;
for (int& dev : device) {
if (dev == -1) dev = 0;
if (dev != other_dev_id) {
if (other_dev_id != -1) ++num_dev;
other_dev_id = dev;
}
}
if (num_dev == 1) {
src.attrs.erase("device_group_attr_key");
src.attrs.erase("device_assign_map");
src.attrs.erase("device_copy_op");
src.attrs["device"] = std::make_shared<any>(std::move(device));
return src;
}
std::map<std::tuple<uint32_t, uint32_t, int>, NodePtr> copy_map;
std::vector<NodePtr> new_node_map(idx.num_nodes(), nullptr);
std::unordered_map<const Node*, int> new_device_map;
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
// insert copy node
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
int dev_id = device[nid];
const auto& inode = idx[nid];
// check if mutation is needed
bool need_mutate = false;
if (!inode.source->is_variable() && fmutate_inputs.count(inode.source->op())) {
for (uint32_t index : fmutate_inputs[inode.source->op()](inode.source->attrs)) {
auto e = inode.inputs[index];
if (new_node_map[e.node_id] != nullptr || dev_id != device[e.node_id]) {
LOG(FATAL) << " mutable state cannot go across device"
<< " op=" << inode.source->op()->name
<< " input_state_index=" << index;
}
}
}
for (const IndexedGraph::NodeEntry& e : inode.inputs) {
if (new_node_map[e.node_id] != nullptr || dev_id != device[e.node_id]) {
need_mutate = true; break;
}
}
if (!need_mutate) {
for (const uint32_t cid : inode.control_deps) {
if (new_node_map[cid] != nullptr) {
need_mutate = true; break;
}
}
}
if (inode.source->is_variable()) {
CHECK(!need_mutate) << "consistency check";
}
if (need_mutate) {
NodePtr new_node = Node::Create();
new_node->attrs = inode.source->attrs;
new_node->inputs.reserve(inode.inputs.size());
for (size_t i = 0; i < inode.inputs.size(); ++i) {
const IndexedGraph::NodeEntry& e = inode.inputs[i];
if (dev_id != device[e.node_id]) {
auto copy_key = std::make_tuple(e.node_id, e.index, dev_id);
auto it = copy_map.find(copy_key);
if (it != copy_map.end() && it->first == copy_key) {
new_node->inputs.emplace_back(
NodeEntry{it->second, 0, 0});
} else {
NodePtr copy_node = Node::Create();
std::ostringstream os;
os << inode.source->inputs[i].node->attrs.name << "_" << e.index <<"_copy";
copy_node->attrs.op = copy_op;
copy_node->attrs.name = os.str();
if (new_node_map[e.node_id] != nullptr) {
copy_node->inputs.emplace_back(
NodeEntry{new_node_map[e.node_id], e.index, 0});
} else {
copy_node->inputs.push_back(inode.source->inputs[i]);
}
if (copy_node->attrs.op->attr_parser != nullptr) {
copy_node->attrs.op->attr_parser(&(copy_node->attrs));
}
copy_map[copy_key] = copy_node;
new_device_map[copy_node.get()] = dev_id;
new_node->inputs.emplace_back(
NodeEntry{std::move(copy_node), 0, 0});
}
} else {
if (new_node_map[e.node_id] != nullptr) {
new_node->inputs.emplace_back(
NodeEntry{new_node_map[e.node_id], e.index, 0});
} else {
new_node->inputs.push_back(inode.source->inputs[i]);
}
}
}
new_node->control_deps.reserve(inode.control_deps.size());
for (size_t i = 0; i < inode.control_deps.size(); ++i) {
uint32_t cid = inode.control_deps[i];
if (new_node_map[cid] != nullptr) {
new_node->control_deps.push_back(new_node_map[cid]);
} else {
new_node->control_deps.push_back(inode.source->control_deps[i]);
}
}
new_device_map[new_node.get()] = dev_id;
new_node_map[nid] = std::move(new_node);
} else {
new_device_map[inode.source] = dev_id;
}
}
// make the new graph
Graph ret;
for (const NodeEntry& e : src.outputs) {
if (new_node_map[idx.node_id(e.node.get())] != nullptr) {
ret.outputs.emplace_back(
NodeEntry{new_node_map[idx.node_id(e.node.get())], e.index, e.version});
} else {
ret.outputs.emplace_back(e);
}
}
DeviceVector new_device_vec(ret.indexed_graph().num_nodes());
for (uint32_t nid = 0; nid < ret.indexed_graph().num_nodes(); ++nid) {
auto source = ret.indexed_graph()[nid].source;
if (new_device_map.count(source) == 0) {
LOG(FATAL) << "canot find " << source;
}
new_device_vec[nid] = new_device_map.at(source);
}
ret.attrs["device"] = std::make_shared<any>(std::move(new_device_vec));
return ret;
}
NNVM_REGISTER_PASS(PlaceDevice)
.describe("Infer the device type of each operator."\
"Insert a copy node when there is cross device copy")
.set_body(PlaceDevice)
.set_change_graph(true)
.provide_graph_attr("device")
.depend_graph_attr("device_group_attr_key")
.depend_graph_attr("device_assign_map")
.depend_graph_attr("device_copy_op");
DMLC_JSON_ENABLE_ANY(DeviceAssignMap, dict_str_int);
} // namespace
} // namespace pass
} // namespace nnvm
//===== EXPANDED: ../nnvm/src/pass/place_device.cc =====
//===== EXPANDING: ../nnvm/src/pass/saveload_json.cc =====
/*!
* Copyright (c) 2016 by Contributors
* \file saveload_json.cc
* \brief Save and load graph to/from JSON file.
*/
//===== EXPANDING: ../nnvm/include/nnvm/pass_functions.h =====
/*!
* Copyright (c) 2016 by Contributors
* \file pass_functions.h
* \brief Pass functions that simply redirect the calls to ApplyPass
*
* This file serves as documentation on how to use functions implemented in "src/pass".
* It is totally optional to add these functions when you add a new pass, since
* ApplyPass can be directly called.
*/
#ifndef NNVM_PASS_FUNCTIONS_H_
#define NNVM_PASS_FUNCTIONS_H_
namespace nnvm {
namespace pass {
/*!
* \brief Load a graph from JSON string, redirects to "LoadJSON" pass.
* \param json_str The json string.
* \return Loaded graph.
*/
inline Graph LoadJSON(const std::string& json_str) {
Graph ret;
ret.attrs["json"] = std::make_shared<any>(json_str);
return ApplyPass(ret, "LoadJSON");
}
/*!
* \brief Save a graph to json, redirects to "SaveJSON" pass.
* \param graph The graph to be saved as json format.
* \return The json string.
*/
inline std::string SaveJSON(Graph graph) {
Graph ret = ApplyPass(std::move(graph), "SaveJSON");
return ret.GetAttr<std::string>("json");
}
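// Round-trip sketch combining the two helpers above (`g` is an assumed Graph):
//   std::string json = SaveJSON(g);
//   Graph g2 = LoadJSON(json);   // reconstructs an equivalent graph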
/*!
* \brief Add control flow dependencies between nodes.
*
* This function will enforce the correct order between
* write (mutable operators) and read (immutable operators)
 * to solve write-after-read and read-after-write problems.
*
* \param src The input graph.
* \return A graph with proper control flow dependencies added.
*/
inline Graph OrderMutation(Graph src) {
return ApplyPass(std::move(src), "OrderMutation");
}
/*!
* \brief Infer shapes in the graph given the information.
* \param graph The input graph.
* \param shape_inputs The shapes of input symbols to the graph.
* \param shape_attr_key The key to the node attribute that can indicate shape. This is
* the place where manual hint for shapes could be injected.
* \return A graph with new attribute "shape" containing inferred shape of each NodeEntry.
* The index of ShapeVector is given by graph.indexed_graph().entry_id.
*/
inline Graph InferShape(Graph graph,
ShapeVector shape_inputs,
std::string shape_attr_key = "") {
if (shape_inputs.size() != 0) {
graph.attrs["shape_inputs"] = std::make_shared<any>(std::move(shape_inputs));
}
if (shape_attr_key.length() != 0) {
graph.attrs["shape_attr_key"] = std::make_shared<any>(std::move(shape_attr_key));
}
return ApplyPass(std::move(graph), "InferShape");
}
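// Usage sketch (the input graph `g` and its single-input shape are assumptions):
//   ShapeVector in_shapes = {TShape({2, 3})};          // one shape per graph input
//   Graph g2 = InferShape(std::move(g), in_shapes);
//   const ShapeVector& shapes = g2.GetAttr<ShapeVector>("shape");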
/*!
* \brief Infer types in the graph given the information.
* \param graph The input graph.
* \param dtype_inputs The types of input symbols to the graph.
* \param dtype_attr_key The key to the node attribute that can indicate types. This is
* the place where manual hint for types could be injected.
* \return A graph with new attribute "dtype" containing inferred type of each NodeEntry.
 * The index of DTypeVector is given by graph.indexed_graph().entry_id.
*/
inline Graph InferType(Graph graph,
DTypeVector dtype_inputs,
std::string dtype_attr_key = "") {
if (dtype_inputs.size() != 0) {
graph.attrs["dtype_inputs"] = std::make_shared<any>(std::move(dtype_inputs));
}
if (dtype_attr_key.length() != 0) {
graph.attrs["dtype_attr_key"] = std::make_shared<any>(std::move(dtype_attr_key));
}
return ApplyPass(std::move(graph), "InferType");
}
/*!
* \brief Place the devices for each operator in the graph.
*
* Current device placement is quite simple. Each operator is assigned to a "group" (stored
* in `device_group_attr_key` attribute). Each group is assigned to a device (stored in
 * `device_assign_map` attribute). Operators will be placed on the device assigned to their
 * group. Copy operators will be injected if a cross-device reference happens.
*
* \param graph The input graph.
* \param device_group_attr_key The attribute name for hints of device group.
* \param device_assign_map The assignment map of device.
 * \param device_copy_op The name of the copy op to be inserted when a cross-device copy happens.
 * \return A graph with new attribute "device", containing device information of each node.
*/
inline Graph PlaceDevice(Graph graph,
std::string device_group_attr_key,
DeviceAssignMap device_assign_map,
std::string device_copy_op) {
graph.attrs["device_group_attr_key"] = std::make_shared<any>(std::move(device_group_attr_key));
graph.attrs["device_assign_map"] = std::make_shared<any>(std::move(device_assign_map));
graph.attrs["device_copy_op"] = std::make_shared<any>(std::move(device_copy_op));
return ApplyPass(std::move(graph), "PlaceDevice");
}
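// Usage sketch (group names, device ids and the copy op name are assumptions):
//   DeviceAssignMap assign = {{"group0", 0}, {"group1", 1}};
//   Graph g2 = PlaceDevice(std::move(g), "__device_group__", assign, "copy_op_name");
//   const DeviceVector& dev = g2.GetAttr<DeviceVector>("device");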
/*!
* \brief Get the gradient graph whose outputs are gradients of xs wrt to ys.
* \param graph The input graph.
* \param ys The entries we want to take gradient from.
* \param xs The input to take gradient with respect to.
 * \param ys_out_grad The symbols for additional gradients to be propagated back to ys.
* \param aggregate_fun Aggregation function applied to aggregate the inputs.
* \param mirror_fun Optional mirror function to do mirror optimization and save memory.
 * \param attr_hint_fun Optional hint function that outputs a node behaving like src but whose attributes match like.
 * \return A new graph whose outputs are the gradients with respect to xs, in the same order as xs.
*/
inline Graph Gradient(
Graph graph,
std::vector<NodeEntry> ys,
std::vector<NodeEntry> xs,
std::vector<NodeEntry> ys_out_grad,
std::function<NodeEntry(std::vector<NodeEntry>&& inputs)> aggregate_fun = nullptr,
std::function<int(const Node& node)> mirror_fun = nullptr,
std::function<NodeEntry(const NodeEntry& src, const NodeEntry &like)>
attr_hint_fun = nullptr) {
graph.attrs["grad_ys"] = std::make_shared<any>(std::move(ys));
graph.attrs["grad_xs"] = std::make_shared<any>(std::move(xs));
graph.attrs["grad_ys_out_grad"] = std::make_shared<any>(std::move(ys_out_grad));
if (aggregate_fun != nullptr) {
graph.attrs["grad_aggregate_fun"] = std::make_shared<any>(aggregate_fun);
}
if (mirror_fun != nullptr) {
graph.attrs["grad_mirror_fun"] = std::make_shared<any>(mirror_fun);
}
if (attr_hint_fun != nullptr) {
graph.attrs["attr_hint_fun"] = std::make_shared<any>(attr_hint_fun);
}
return ApplyPass(std::move(graph), "Gradient");
}
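// Usage sketch (`fwd`, `ys`, `xs` and `head_grads` are assumed to exist):
//   Graph grad_g = Gradient(fwd, ys, xs, head_grads);
//   // grad_g.outputs[i] is the gradient of ys with respect to xs[i]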
} // namespace pass
} // namespace nnvm
#endif // NNVM_PASS_FUNCTIONS_H_
//===== EXPANDED: ../nnvm/include/nnvm/pass_functions.h =====
namespace dmlc {
namespace json {
// overload handler for shared ptr
template<>
struct Handler<std::shared_ptr<any> > {
inline static void Write(JSONWriter *writer, const std::shared_ptr<any> &data) {
writer->Write(*data);
}
inline static void Read(JSONReader *reader, std::shared_ptr<any> *data) {
any v;
reader->Read(&v);
*data = std::make_shared<any>(std::move(v));
}
};
} // namespace json
} // namespace dmlc
namespace nnvm {
namespace pass {
namespace {
// auxiliary node structure for serialization.
struct JSONNode {
// the node entry structure in serialized format
struct Entry {
uint32_t node_id;
uint32_t index;
uint32_t version;
void Save(dmlc::JSONWriter *writer) const {
writer->BeginArray(false);
writer->WriteArrayItem(node_id);
writer->WriteArrayItem(index);
writer->WriteArrayItem(version);
writer->EndArray();
}
void Load(dmlc::JSONReader *reader) {
reader->BeginArray();
CHECK(reader->NextArrayItem()) << "invalid json format";
reader->Read(&node_id);
CHECK(reader->NextArrayItem()) << "invalid json format";
reader->Read(&index);
if (reader->NextArrayItem()) {
reader->Read(&version);
CHECK(!reader->NextArrayItem()) << "invalid json format";
} else {
version = 0;
}
}
};
// pointer to the graph node
NodePtr node;
// inputs
std::vector<Entry> inputs;
// control flow dependencies
std::vector<uint32_t> control_deps;
// function to save JSON node.
void Save(dmlc::JSONWriter *writer) const {
writer->BeginObject();
if (node->op() != nullptr) {
writer->WriteObjectKeyValue("op", node->op()->name);
} else {
std::string json_null = "null";
writer->WriteObjectKeyValue("op", json_null);
}
writer->WriteObjectKeyValue("name", node->attrs.name);
if (node->attrs.dict.size() != 0) {
// write attributes in order;
std::map<std::string, std::string> dict(
node->attrs.dict.begin(), node->attrs.dict.end());
writer->WriteObjectKeyValue("attr", dict);
}
writer->WriteObjectKeyValue("inputs", inputs);
if (control_deps.size() != 0) {
writer->WriteObjectKeyValue("control_deps", control_deps);
}
writer->EndObject();
}
void Load(dmlc::JSONReader *reader) {
node = Node::Create();
control_deps.clear();
dmlc::JSONObjectReadHelper helper;
std::string op_type_str;
helper.DeclareField("op", &op_type_str);
helper.DeclareField("name", &(node->attrs.name));
helper.DeclareField("inputs", &inputs);
helper.DeclareOptionalField("attr", &(node->attrs.dict));
helper.DeclareOptionalField("control_deps", &control_deps);
// backward compatible code with mxnet graph.
int backward_source_id;
std::unordered_map<std::string, std::string> param;
helper.DeclareOptionalField("param", &param);
helper.DeclareOptionalField("backward_source_id", &backward_source_id);
helper.ReadAllFields(reader);
node->attrs.dict.insert(param.begin(), param.end());
if (op_type_str != "null") {
try {
node->attrs.op = Op::Get(op_type_str);
} catch (const dmlc::Error &err) {
std::ostringstream os;
os << "Failed loading Op " << node->attrs.name
<< " of type " << op_type_str << ": " << err.what();
throw dmlc::Error(os.str());
}
} else {
node->attrs.op = nullptr;
}
}
};
// graph structure to help read/save JSON.
struct JSONGraph {
std::vector<JSONNode> nodes;
std::vector<uint32_t> arg_nodes;
std::vector<uint32_t> node_row_ptr;
std::vector<JSONNode::Entry> heads;
std::unordered_map<std::string, std::shared_ptr<any> > attrs;
void Save(dmlc::JSONWriter *writer) const {
writer->BeginObject();
writer->WriteObjectKeyValue("nodes", nodes);
writer->WriteObjectKeyValue("arg_nodes", arg_nodes);
writer->WriteObjectKeyValue("node_row_ptr", node_row_ptr);
writer->WriteObjectKeyValue("heads", heads);
if (attrs.size() != 0) {
writer->WriteObjectKeyValue("attrs", attrs);
}
writer->EndObject();
}
void Load(dmlc::JSONReader *reader) {
attrs.clear();
dmlc::JSONObjectReadHelper helper;
helper.DeclareField("nodes", &nodes);
helper.DeclareField("arg_nodes", &arg_nodes);
helper.DeclareField("heads", &heads);
helper.DeclareOptionalField("node_row_ptr", &node_row_ptr);
helper.DeclareOptionalField("attrs", &attrs);
helper.ReadAllFields(reader);
}
};
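// Illustrative shape of the serialized JSON produced by the structs above
// (op/attribute names are made up; inputs/heads entries are [node_id, index, version],
// and "attrs"/"control_deps" only appear when non-empty):
//   {
//     "nodes": [ {"op": "null",    "name": "x", "inputs": []},
//                {"op": "some_op", "name": "y", "inputs": [[0, 0, 0]]} ],
//     "arg_nodes": [0],
//     "node_row_ptr": [0, 1, 2],
//     "heads": [[1, 0, 0]]
//   }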
// Load a graph from JSON file.
Graph LoadJSON(Graph src) {
CHECK_NE(src.attrs.count("json"), 0U)
<< "Load JSON require json to be presented.";
const std::string &json_str =
nnvm::get<std::string>(*src.attrs.at("json"));
bool no_parse = false;
if (src.attrs.count("load_json_no_parse")) {
no_parse = nnvm::get<bool>(*src.attrs.at("load_json_no_parse"));
}
std::istringstream is(json_str);
dmlc::JSONReader reader(&is);
JSONGraph jgraph;
// load in json graph.
jgraph.Load(&reader);
// connects the nodes
for (JSONNode &n : jgraph.nodes) {
n.node->inputs.reserve(n.inputs.size());
for (const JSONNode::Entry &e : n.inputs) {
n.node->inputs.emplace_back(
NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version});
}
n.node->control_deps.reserve(n.control_deps.size());
for (uint32_t nid : n.control_deps) {
n.node->control_deps.push_back(jgraph.nodes[nid].node);
}
// rebuild attribute parser
if (!no_parse && n.node->op() != nullptr &&
n.node->op()->attr_parser != nullptr) {
n.node->op()->attr_parser(&(n.node->attrs));
}
}
// consistency check
for (uint32_t nid : jgraph.arg_nodes) {
CHECK(jgraph.nodes[nid].node->is_variable());
}
// return the graph
Graph ret;
ret.attrs = std::move(jgraph.attrs);
ret.outputs.reserve(jgraph.heads.size());
for (const JSONNode::Entry &e : jgraph.heads) {
ret.outputs.emplace_back(
NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version});
}
return ret;
}
// save a graph to json
Graph SaveJSON(Graph src) {
JSONGraph jgraph;
jgraph.attrs = src.attrs;
std::unordered_map<Node*, uint32_t> node2index;
jgraph.node_row_ptr.push_back(0);
DFSVisit(src.outputs, [&node2index, &jgraph](const NodePtr& n) {
uint32_t nid = static_cast<uint32_t>(jgraph.nodes.size());
node2index[n.get()] = nid;
if (n->is_variable()) {
jgraph.arg_nodes.push_back(nid);
}
JSONNode jnode;
jnode.node = n;
jnode.inputs.reserve(n->inputs.size());
for (const NodeEntry& e : n->inputs) {
jnode.inputs.emplace_back(
JSONNode::Entry{node2index.at(e.node.get()), e.index, e.version});
}
for (const NodePtr& c : n->control_deps) {
jnode.control_deps.push_back(node2index.at(c.get()));
}
jgraph.node_row_ptr.push_back(
jgraph.node_row_ptr.back() + n->num_outputs());
jgraph.nodes.emplace_back(std::move(jnode));
});
for (const NodeEntry& e : src.outputs) {
jgraph.heads.push_back(
JSONNode::Entry{node2index.at(e.node.get()), e.index, e.version});
}
std::ostringstream os;
dmlc::JSONWriter writer(&os);
jgraph.Save(&writer);
Graph ret;
ret.attrs["json"] = std::make_shared<any>(os.str());
return ret;
}
// register pass
NNVM_REGISTER_PASS(LoadJSON)
.describe("Return a new Graph, loaded from src.attrs[\"json\"]")
.set_body(LoadJSON)
.set_change_graph(true)
.depend_graph_attr("json");
NNVM_REGISTER_PASS(SaveJSON)
.describe("Return a new empty Graph. Save graph to ret.attrs[\"json\"]")
.set_body(SaveJSON)
.set_change_graph(true)
.provide_graph_attr("json");
DMLC_JSON_ENABLE_ANY(std::string, str);
DMLC_JSON_ENABLE_ANY(std::vector<int>, list_int);
} // namespace
} // namespace pass
} // namespace nnvm
//===== EXPANDED: ../nnvm/src/pass/saveload_json.cc =====
//===== EXPANDING: ../nnvm/src/c_api/c_api_error.cc =====
/*!
* Copyright (c) 2016 by Contributors
* \file c_api_error.cc
* \brief C error handling
*/
//===== EXPANDING: ../nnvm/src/c_api/c_api_common.h =====
/*!
* Copyright (c) 2016 by Contributors
* \file c_api_error.h
* \brief Common fields of all C APIs
*/
#ifndef NNVM_C_API_C_API_COMMON_H_
#define NNVM_C_API_C_API_COMMON_H_
//===== EXPANDING: ../nnvm/include/nnvm/c_api.h =====
/*!
* Copyright (c) 2016 by Contributors
* \file c_api.h
* \brief C API of NNVM symbolic construction and pass.
* Enables construction and transformation of Graph
* in any other host languages.
*/
#ifndef NNVM_C_API_H_
#define NNVM_C_API_H_
#ifdef __cplusplus
#define NNVM_EXTERN_C extern "C"
#else
#define NNVM_EXTERN_C
#endif
/*! \brief NNVM_DLL prefix for windows */
#ifdef _WIN32
#ifdef NNVM_EXPORTS
#define NNVM_DLL NNVM_EXTERN_C __declspec(dllexport)
#else
#define NNVM_DLL NNVM_EXTERN_C __declspec(dllimport)
#endif
#else
#define NNVM_DLL NNVM_EXTERN_C
#endif
/*! \brief manually define unsigned int */
typedef unsigned int nn_uint;
/*! \brief handle to a function that takes param and creates symbol */
typedef void *OpHandle;
/*! \brief handle to a symbol that can be bind as operator */
typedef void *SymbolHandle;
/*! \brief handle to Graph */
typedef void *GraphHandle;
/*!
* \brief Set the last error message needed by C API
* \param msg The error message to set.
*/
NNVM_DLL void NNAPISetLastError(const char* msg);
/*!
 * \brief Return the string message of the last error.
 *  All functions in this file return 0 on success
 *  and -1 when an error occurs;
 *  NNGetLastError can be called to retrieve the error.
 *
 *  This function is thread-safe and can be called from different threads.
* \return error info
*/
NNVM_DLL const char *NNGetLastError(void);
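/*
 * Minimal error-handling sketch for this C API (the creation call is only
 * illustrative; any function returning non-zero can be handled the same way):
 *   SymbolHandle out;
 *   if (NNSymbolCreateVariable("x", &out) != 0) {
 *     fprintf(stderr, "nnvm error: %s\n", NNGetLastError());
 *   }
 */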
/*!
 * \brief List all the available operator names, including alias entries.
* \param out_size the size of returned array
* \param out_array the output operator name array.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNListAllOpNames(nn_uint *out_size,
const char*** out_array);
/*!
* \brief Get operator handle given name.
* \param op_name The name of the operator.
 * \param op_out The returned op handle.
*/
NNVM_DLL int NNGetOpHandle(const char* op_name,
OpHandle* op_out);
/*!
* \brief list all the available operators.
 * This won't include the aliases; use NNListAllOpNames
 * instead to get all names including aliases.
*
* \param out_size the size of returned array
* \param out_array the output AtomicSymbolCreator array
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNListUniqueOps(nn_uint *out_size,
OpHandle **out_array);
/*!
* \brief Get the detailed information about atomic symbol.
* \param op The operator handle.
* \param real_name The returned name of the creator.
* This name is not the alias name of the atomic symbol.
* \param description The returned description of the symbol.
* \param num_doc_args Number of arguments that contain documents.
* \param arg_names Name of the arguments of doc args
 * \param arg_type_infos Type information about the arguments.
* \param arg_descriptions Description information about the arguments.
* \param return_type Return type of the function, if any.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNGetOpInfo(OpHandle op,
const char **real_name,
const char **description,
nn_uint *num_doc_args,
const char ***arg_names,
const char ***arg_type_infos,
const char ***arg_descriptions,
const char **return_type);
/*!
* \brief Create an AtomicSymbol functor.
* \param op The operator handle
* \param num_param the number of parameters
* \param keys the keys to the params
* \param vals the vals of the params
* \param out pointer to the created symbol handle
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNSymbolCreateAtomicSymbol(OpHandle op,
nn_uint num_param,
const char **keys,
const char **vals,
SymbolHandle *out);
/*!
* \brief Create a Variable Symbol.
* \param name name of the variable
* \param out pointer to the created symbol handle
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNSymbolCreateVariable(const char *name, SymbolHandle *out);
/*!
* \brief Create a Symbol by grouping list of symbols together
* \param num_symbols number of symbols to be grouped
* \param symbols array of symbol handles
* \param out pointer to the created symbol handle
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNSymbolCreateGroup(nn_uint num_symbols,
SymbolHandle *symbols,
SymbolHandle *out);
/*!
* \brief Add src_dep to the handle as control dep.
* \param handle The symbol to add dependency edges on.
 * \param src_dep the source handle.
*/
NNVM_DLL int NNAddControlDeps(SymbolHandle handle,
SymbolHandle src_dep);
/*!
* \brief Free the symbol handle.
* \param symbol the symbol
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNSymbolFree(SymbolHandle symbol);
/*!
* \brief Copy the symbol to another handle
* \param symbol the source symbol
* \param out used to hold the result of copy
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNSymbolCopy(SymbolHandle symbol, SymbolHandle *out);
/*!
* \brief Print the content of symbol, used for debug.
* \param symbol the symbol
* \param out_str pointer to hold the output string of the printing.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNSymbolPrint(SymbolHandle symbol, const char **out_str);
/*!
* \brief Get string attribute from symbol
* \param symbol the source symbol
* \param key The key of the symbol.
 * \param out The result attribute, can be NULL if the attribute does not exist.
* \param success Whether the result is contained in out.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNSymbolGetAttr(SymbolHandle symbol,
const char* key,
const char** out,
int *success);
/*!
* \brief Set string attribute from symbol.
 *  NOTE: Setting an attribute on a symbol can affect the semantics (mutable/immutable) of the symbolic graph.
 *
 *  Safe recommendation: use an immutable graph
 *  - Only allow setting attributes during creation of a new symbol as an optional parameter
 *
 *  Mutable graph (be careful about the semantics):
 *  - Allow setting attributes at any point.
 *  - Mutating an attribute of a node shared by two graphs can confuse the user.
*
* \param symbol the source symbol
* \param num_param Number of parameters to set.
* \param keys The keys of the attribute
* \param values The value to be set
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNSymbolSetAttrs(SymbolHandle symbol,
nn_uint num_param,
const char** keys,
const char** values);
/*!
 * \brief Get all attributes from the symbol, including all descendants.
* \param symbol the source symbol
* \param recursive_option 0 for recursive, 1 for shallow.
* \param out_size The number of output attributes
* \param out 2*out_size strings representing key value pairs.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNSymbolListAttrs(SymbolHandle symbol,
int recursive_option,
nn_uint *out_size,
const char*** out);
/*!
* \brief List inputs variables in the symbol.
* \param symbol the symbol
* \param option The option to list the inputs
* option=0 means list all arguments.
 *   option=1 means list arguments that are read only by the graph.
* option=2 means list arguments that are mutated by the graph.
* \param out_size output size
* \param out_sym_array the output array.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNSymbolListInputVariables(SymbolHandle symbol,
int option,
nn_uint *out_size,
SymbolHandle** out_sym_array);
/*!
* \brief List input names in the symbol.
* \param symbol the symbol
* \param option The option to list the inputs
* option=0 means list all arguments.
 *   option=1 means list arguments that are read only by the graph.
* option=2 means list arguments that are mutated by the graph.
* \param out_size output size
* \param out_str_array pointer to hold the output string array
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNSymbolListInputNames(SymbolHandle symbol,
int option,
nn_uint *out_size,
const char ***out_str_array);
/*!
 * \brief List output names of the symbol.
* \param symbol the symbol
* \param out_size output size
* \param out_str_array pointer to hold the output string array
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNSymbolListOutputNames(SymbolHandle symbol,
nn_uint *out_size,
const char ***out_str_array);
/*!
* \brief Get a symbol that contains all the internals.
* \param symbol The symbol
* \param out The output symbol whose outputs are all the internals.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNSymbolGetInternals(SymbolHandle symbol,
SymbolHandle *out);
/*!
* \brief Get index-th outputs of the symbol.
* \param symbol The symbol
* \param index the Index of the output.
* \param out The output symbol whose outputs are the index-th symbol.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNSymbolGetOutput(SymbolHandle symbol,
nn_uint index,
SymbolHandle *out);
/*!
* \brief Compose the symbol on other symbols.
*
 *  This function will change the sym handle.
 *  To achieve function-apply behavior, copy the symbol first
 *  before applying.
*
* \param sym the symbol to apply
* \param name the name of symbol
* \param num_args number of arguments
* \param keys the key of keyword args (optional)
* \param args arguments to sym
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNSymbolCompose(SymbolHandle sym,
const char* name,
nn_uint num_args,
const char** keys,
SymbolHandle* args);
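/*
 * Minimal composition sketch (illustration only; error checking omitted, and
 * the operator name "exp" and keyword "data" are assumptions, not guaranteed
 * by this header): create a variable, create an atomic symbol for an
 * operator, then compose it on the variable.
 *
 *   SymbolHandle x, op_sym;
 *   OpHandle op;
 *   NNSymbolCreateVariable("x", &x);
 *   NNGetOpHandle("exp", &op);
 *   NNSymbolCreateAtomicSymbol(op, 0, NULL, NULL, &op_sym);
 *   const char* keys[] = {"data"};
 *   SymbolHandle args[] = {x};
 *   NNSymbolCompose(op_sym, "y", 1, keys, args);  // op_sym now computes exp(x)
 */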
// Graph IR API
/*!
* \brief create a graph handle from symbol
* \param symbol The symbol representing the graph.
* \param graph The graph handle created.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNGraphCreate(SymbolHandle symbol, GraphHandle *graph);
/*!
* \brief free the graph handle
* \param handle The handle to be freed.
*/
NNVM_DLL int NNGraphFree(GraphHandle handle);
/*!
* \brief Get a new symbol from the graph.
* \param graph The graph handle.
* \param symbol The corresponding symbol
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol);
/*!
 * \brief Set a graph attribute in JSON format.
 *  This feature allows passing graph attributes back and forth at reasonable speed.
*
* \param handle The graph handle.
* \param key The key to the attribute.
 * \param json_value The value needs to be in the format [type_name, value],
 *  where type_name is a type string registered on the C++ side via DMLC_JSON_ENABLE_ANY.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNGraphSetJSONAttr(GraphHandle handle,
const char* key,
const char* json_value);
/*!
 * \brief Get a serialized attribute from the graph.
 *  This feature allows passing graph attributes back and forth at reasonable speed.
*
* \param handle The graph handle.
* \param key The key to the attribute.
 * \param json_out The result attribute, can be NULL if the attribute does not exist.
 *  The json_out is an array of [type_name, value],
 *  where type_name is a type string registered on the C++ side via DMLC_JSON_ENABLE_ANY.
* \param success Whether the result is contained in out.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNGraphGetJSONAttr(GraphHandle handle,
const char* key,
const char** json_out,
int *success);
/*!
 * \brief Set an attribute whose type is std::vector<NodeEntry> in C++.
 *  This feature allows passing a list of symbolic variables for a gradient request.
 *
 * \note This is a beta feature, only used for test purposes
*
* \param handle The graph handle.
* \param key The key to the attribute.
 * \param list The symbol whose outputs represent the list of NodeEntry to be passed.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNGraphSetNodeEntryListAttr_(GraphHandle handle,
const char* key,
SymbolHandle list);
/*!
* \brief Apply passes on the src graph.
* \param src The source graph handle.
 * \param num_pass The number of passes to be applied.
 * \param pass_names The names of the passes.
* \param dst The result graph.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNGraphApplyPasses(GraphHandle src,
nn_uint num_pass,
const char** pass_names,
GraphHandle *dst);
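/*
 * Illustrative pass-application sketch (not part of the original header; the
 * pass name "SaveJSON" and the output attribute name "json" are assumed to be
 * registered on the C++ side):
 *
 *   GraphHandle g, g_out;
 *   NNGraphCreate(sym, &g);
 *   const char* passes[] = {"SaveJSON"};
 *   NNGraphApplyPasses(g, 1, passes, &g_out);
 *   const char* json; int found;
 *   NNGraphGetJSONAttr(g_out, "json", &json, &found);
 */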
#endif // NNVM_C_API_H_
//===== EXPANDED: ../nnvm/include/nnvm/c_api.h =====
/*! \brief macro to guard beginning and end section of all functions */
#define API_BEGIN() try {
/*! \brief every function starts with API_BEGIN();
and finishes with API_END() or API_END_HANDLE_ERROR */
#define API_END() } catch(dmlc::Error &_except_) { return NNAPIHandleException(_except_); } return 0; // NOLINT(*)
/*!
* \brief every function starts with API_BEGIN();
* and finishes with API_END() or API_END_HANDLE_ERROR
* The finally clause contains procedure to cleanup states when an error happens.
*/
#define API_END_HANDLE_ERROR(Finalize) } catch(dmlc::Error &_except_) { Finalize; return NNAPIHandleException(_except_); } return 0; // NOLINT(*)
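/*
 * Canonical shape of a C API function using these guards (a sketch that
 * mirrors the real implementations further below; MyAPICall is hypothetical):
 *
 *   int MyAPICall(SymbolHandle handle) {
 *     API_BEGIN();
 *     // body that may throw dmlc::Error
 *     API_END();
 *   }
 */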
/*! \brief entry to easily hold returned information */
struct NNAPIThreadLocalEntry {
/*! \brief result holder for returning string */
std::string ret_str;
/*! \brief result holder for returning strings */
std::vector<std::string> ret_vec_str;
/*! \brief result holder for returning string pointers */
std::vector<const char *> ret_vec_charp;
/*! \brief result holder for returning handles */
std::vector<void *> ret_handles;
/*! \brief argument holder to hold symbol */
std::unordered_map<std::string, const nnvm::Symbol*> kwarg_symbol;
};
/*! \brief Thread local store that can be used to hold return values. */
typedef dmlc::ThreadLocalStore<NNAPIThreadLocalEntry> NNAPIThreadLocalStore;
/*!
 * \brief handle the exception thrown out
* \param e the exception
* \return the return value of API after exception is handled
*/
inline int NNAPIHandleException(const dmlc::Error &e) {
NNAPISetLastError(e.what());
return -1;
}
#endif // NNVM_C_API_C_API_COMMON_H_
//===== EXPANDED: ../nnvm/src/c_api/c_api_common.h =====
struct ErrorEntry {
std::string last_error;
};
typedef dmlc::ThreadLocalStore<ErrorEntry> NNAPIErrorStore;
const char *NNGetLastError() {
return NNAPIErrorStore::Get()->last_error.c_str();
}
void NNAPISetLastError(const char* msg) {
NNAPIErrorStore::Get()->last_error = msg;
}
//===== EXPANDED: ../nnvm/src/c_api/c_api_error.cc =====
//===== EXPANDING: ../nnvm/src/c_api/c_api_graph.cc =====
/*!
* Copyright (c) 2016 by Contributors
* \file c_api_graph.cc
* \brief C API related to Graph IR.
*/
using namespace nnvm;
int NNGraphCreate(SymbolHandle symbol, GraphHandle *graph) {
Graph* g = new Graph();
API_BEGIN();
g->outputs = static_cast<Symbol*>(symbol)->outputs;
*graph = g;
API_END_HANDLE_ERROR(delete g);
}
int NNGraphFree(GraphHandle handle) {
API_BEGIN();
delete static_cast<Graph*>(handle);
API_END();
}
int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol) {
Symbol* s = new Symbol();
API_BEGIN();
s->outputs = static_cast<Graph*>(graph)->outputs;
*symbol = s;
API_END_HANDLE_ERROR(delete s);
}
int NNGraphSetNodeEntryListAttr_(GraphHandle handle,
const char* key,
SymbolHandle list) {
API_BEGIN();
Symbol* s = static_cast<Symbol*>(list);
Graph* g = static_cast<Graph*>(handle);
g->attrs[std::string(key)]
= std::make_shared<any>(s->outputs);
API_END();
}
int NNGraphSetJSONAttr(GraphHandle handle,
const char* key,
const char* json_value) {
API_BEGIN();
Graph* g = static_cast<Graph*>(handle);
std::string temp(json_value);
std::istringstream is(temp);
dmlc::JSONReader reader(&is);
nnvm::any value;
reader.Read(&value);
g->attrs[std::string(key)] = std::make_shared<any>(std::move(value));
API_END();
}
int NNGraphGetJSONAttr(GraphHandle handle,
const char* key,
const char** json_out,
int *success) {
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
API_BEGIN();
Graph* g = static_cast<Graph*>(handle);
std::string skey(key);
auto it = g->attrs.find(skey);
if (it != g->attrs.end()) {
std::ostringstream os;
dmlc::JSONWriter writer(&os);
writer.Write(*it->second.get());
ret->ret_str = os.str();
*json_out = (ret->ret_str).c_str();
*success = 1;
} else {
*success = 0;
}
API_END();
}
int NNGraphApplyPasses(GraphHandle src,
nn_uint num_pass,
const char** pass_names,
GraphHandle *dst) {
Graph* g = new Graph();
API_BEGIN();
std::vector<std::string> vpass;
for (nn_uint i = 0; i < num_pass; ++i) {
vpass.emplace_back(std::string(pass_names[i]));
}
*g = ApplyPasses(*static_cast<Graph*>(src), vpass);
*dst = g;
API_END_HANDLE_ERROR(delete g);
}
//===== EXPANDED: ../nnvm/src/c_api/c_api_graph.cc =====
//===== EXPANDING: ../nnvm/src/c_api/c_api_symbolic.cc =====
/*!
* Copyright (c) 2016 by Contributors
* \file c_api_symbolic.cc
 * \brief C API related to symbolic graph composition.
*/
using namespace nnvm;
int NNListAllOpNames(nn_uint *out_size,
const char*** out_array) {
API_BEGIN();
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
ret->ret_vec_str = dmlc::Registry<Op>::ListAllNames();
ret->ret_vec_charp.clear();
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
}
*out_array = dmlc::BeginPtr(ret->ret_vec_charp);
*out_size = static_cast<nn_uint>(ret->ret_vec_str.size());
API_END();
}
int NNGetOpHandle(const char* op_name,
OpHandle* op_out) {
API_BEGIN();
*op_out = (OpHandle)Op::Get(op_name); // NOLINT(*)
API_END();
}
int NNListUniqueOps(nn_uint *out_size,
OpHandle **out_array) {
API_BEGIN();
auto &vec = dmlc::Registry<Op>::List();
*out_size = static_cast<nn_uint>(vec.size());
*out_array = (OpHandle*)(dmlc::BeginPtr(vec)); // NOLINT(*)
API_END();
}
int NNAddControlDeps(SymbolHandle handle,
SymbolHandle src_dep) {
API_BEGIN();
static_cast<Symbol*>(handle)->AddControlDeps(
*static_cast<Symbol*>(src_dep));
API_END();
}
int NNGetOpInfo(OpHandle handle,
const char **name,
const char **description,
nn_uint *num_doc_args,
const char ***arg_names,
const char ***arg_type_infos,
const char ***arg_descriptions,
const char **return_type) {
const Op *op = static_cast<const Op *>(handle);
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
API_BEGIN();
*name = op->name.c_str();
*description = op->description.c_str();
*num_doc_args = static_cast<nn_uint>(op->arguments.size());
if (return_type) *return_type = nullptr;
ret->ret_vec_charp.clear();
for (size_t i = 0; i < op->arguments.size(); ++i) {
ret->ret_vec_charp.push_back(op->arguments[i].name.c_str());
}
for (size_t i = 0; i < op->arguments.size(); ++i) {
ret->ret_vec_charp.push_back(op->arguments[i].type_info_str.c_str());
}
for (size_t i = 0; i < op->arguments.size(); ++i) {
ret->ret_vec_charp.push_back(op->arguments[i].description.c_str());
}
*arg_names = dmlc::BeginPtr(ret->ret_vec_charp);
*arg_type_infos = dmlc::BeginPtr(ret->ret_vec_charp) + op->arguments.size();
*arg_descriptions = dmlc::BeginPtr(ret->ret_vec_charp) + (op->arguments.size() * 2);
API_END();
}
int NNSymbolCreateAtomicSymbol(OpHandle creator,
nn_uint num_param,
const char **keys,
const char **vals,
SymbolHandle *out) {
Symbol *s = new Symbol();
API_BEGIN();
const Op* op = static_cast<const Op*>(creator);
std::unordered_map<std::string, std::string> kwargs;
for (nn_uint i = 0; i < num_param; ++i) {
kwargs.insert({std::string(keys[i]), std::string(vals[i])});
}
*s = Symbol::CreateFunctor(op, std::move(kwargs));
*out = s;
API_END_HANDLE_ERROR(delete s;);
}
int NNSymbolCreateVariable(const char *name, SymbolHandle *out) {
Symbol *s = new Symbol();
API_BEGIN();
*s = Symbol::CreateVariable(name);
*out = s;
API_END_HANDLE_ERROR(delete s);
}
int NNSymbolCreateGroup(nn_uint num_symbols,
SymbolHandle *symbols,
SymbolHandle *out) {
Symbol *s = new Symbol();
Symbol **sym_arr = (Symbol**)symbols; // NOLINT(*)
API_BEGIN();
std::vector<Symbol> syms;
for (nn_uint i = 0; i < num_symbols; ++i) {
syms.push_back(*sym_arr[i]);
}
*s = Symbol::CreateGroup(syms);
*out = s;
API_END_HANDLE_ERROR(delete s);
}
int NNSymbolGetOutput(SymbolHandle symbol,
nn_uint index,
SymbolHandle *out) {
Symbol *s = new Symbol();
API_BEGIN();
*s = (*static_cast<Symbol*>(symbol))[index];
*out = s;
API_END_HANDLE_ERROR(delete s);
}
int NNSymbolGetInternals(SymbolHandle symbol,
SymbolHandle *out) {
Symbol *s = new Symbol();
API_BEGIN();
*s = static_cast<Symbol*>(symbol)->GetInternals();
*out = s;
API_END_HANDLE_ERROR(delete s);
}
int NNSymbolFree(SymbolHandle symbol) {
API_BEGIN();
delete static_cast<Symbol*>(symbol);
API_END();
}
int NNSymbolCopy(SymbolHandle symbol, SymbolHandle *out) {
Symbol *s = new Symbol();
API_BEGIN();
*s = static_cast<const Symbol*>(symbol)->Copy();
*out = s;
API_END_HANDLE_ERROR(delete s);
}
int NNSymbolPrint(SymbolHandle symbol, const char **out_str) {
Symbol *s = static_cast<Symbol*>(symbol);
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
API_BEGIN();
std::ostringstream os;
s->Print(os);
ret->ret_str = os.str();
*out_str = (ret->ret_str).c_str();
API_END();
}
int NNSymbolGetAttr(SymbolHandle symbol,
const char* key,
const char** out,
int* success) {
Symbol *s = static_cast<Symbol*>(symbol);
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
API_BEGIN();
if (s->GetAttr(key, &(ret->ret_str))) {
*out = (ret->ret_str).c_str();
*success = 1;
} else {
*out = nullptr;
*success = 0;
}
API_END();
}
int NNSymbolSetAttrs(SymbolHandle symbol,
nn_uint num_param,
const char** keys,
const char** vals) {
Symbol *s = static_cast<Symbol*>(symbol);
API_BEGIN();
std::vector<std::pair<std::string, std::string> > kwargs;
for (nn_uint i = 0; i < num_param; ++i) {
kwargs.emplace_back(
std::make_pair(std::string(keys[i]), std::string(vals[i])));
}
s->SetAttrs(kwargs);
API_END();
}
int NNSymbolListAttrs(SymbolHandle symbol,
int option,
nn_uint *out_size,
const char*** out) {
Symbol *s = static_cast<Symbol*>(symbol);
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
API_BEGIN();
std::unordered_map<std::string, std::string> attr =
s->ListAttrs(static_cast<Symbol::ListAttrOption>(option)); // NOLINT(*)
std::vector<std::string>& attr_list = ret->ret_vec_str;
attr_list.clear();
for (const auto& kv : attr) {
attr_list.push_back(kv.first);
attr_list.push_back(kv.second);
}
*out_size = attr.size();
ret->ret_vec_charp.clear();
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
}
*out = dmlc::BeginPtr(ret->ret_vec_charp);
API_END();
}
int NNSymbolListInputVariables(SymbolHandle symbol,
int option,
nn_uint *out_size,
SymbolHandle** out_sym_array) {
Symbol *s = static_cast<Symbol*>(symbol);
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
API_BEGIN();
std::vector<NodePtr> vs = s->ListInputs(Symbol::ListInputOption(option));
ret->ret_handles.clear();
for (size_t i = 0; i < vs.size(); ++i) {
nnvm::Symbol* rs = new nnvm::Symbol();
rs->outputs.push_back(NodeEntry{vs[i], 0, 0});
ret->ret_handles.push_back(rs);
}
*out_size = static_cast<nn_uint>(vs.size());
*out_sym_array = dmlc::BeginPtr(ret->ret_handles);
API_END();
}
int NNSymbolListInputNames(SymbolHandle symbol,
int option,
nn_uint *out_size,
const char ***out_str_array) {
Symbol *s = static_cast<Symbol*>(symbol);
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
API_BEGIN();
ret->ret_vec_str =
s->ListInputNames(Symbol::ListInputOption(option));
ret->ret_vec_charp.clear();
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
}
*out_size = static_cast<nn_uint>(ret->ret_vec_charp.size());
*out_str_array = dmlc::BeginPtr(ret->ret_vec_charp);
API_END();
}
int NNSymbolListOutputNames(SymbolHandle symbol,
nn_uint *out_size,
const char ***out_str_array) {
Symbol *s = static_cast<Symbol*>(symbol);
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
API_BEGIN();
ret->ret_vec_str = s->ListOutputNames();
ret->ret_vec_charp.clear();
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
}
*out_size = static_cast<nn_uint>(ret->ret_vec_charp.size());
*out_str_array = dmlc::BeginPtr(ret->ret_vec_charp);
API_END();
}
int NNSymbolCompose(SymbolHandle sym,
const char *name,
nn_uint num_args,
const char** keys,
SymbolHandle* args) {
API_BEGIN();
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
std::string& s_name = ret->ret_str;
std::unordered_map<std::string, const Symbol*>& kwargs
= ret->kwarg_symbol;
kwargs.clear();
if (name != nullptr) {
s_name = name;
} else {
s_name.clear();
}
Symbol* s = static_cast<Symbol*>(sym);
if (keys == nullptr && num_args != 0) {
kwargs.clear();
array_view<const Symbol*> parg(
(Symbol**)args, (Symbol**)args + num_args); // NOLINT(*)
s->Compose(parg, kwargs, s_name);
} else {
for (nn_uint i = 0; i < num_args; ++i) {
kwargs[keys[i]] = (Symbol*)args[i]; // NOLINT(*)
}
s->Compose(array_view<const Symbol*>(), kwargs, s_name);
}
API_END();
}
//===== EXPANDED: ../nnvm/src/c_api/c_api_symbolic.cc =====
//===== EXPANDED: nnvm.cc =====
//===== EXPANDING: mxnet_predict0.cc =====
// mxnet.cc
#if defined(__ANDROID__) || defined(__MXNET_JS__)
#define MSHADOW_USE_SSE 0
#endif
//===== EXPANDING: ../src/ndarray/ndarray_function.cc =====
/*!
* Copyright (c) 2015 by Contributors
* \file ndarray_function_cpu.cc
* \brief CPU Implementation of ndarray function.
*/
// this will be invoked by gcc and compile CPU version
//===== EXPANDING: ../src/ndarray/ndarray_function.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file ndarray_op.h
* \brief the real execution functions of ndarray operations
*/
#ifndef MXNET_NDARRAY_NDARRAY_FUNCTION_H_
#define MXNET_NDARRAY_NDARRAY_FUNCTION_H_
//===== EXPANDING: ../include/mxnet/resource.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file resource.h
* \brief Global resource allocation handling.
*/
#ifndef MXNET_RESOURCE_H_
#define MXNET_RESOURCE_H_
//===== EXPANDING: ../include/mxnet/engine.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file engine.h
* \brief Engine that schedules all the operations according to dependency.
*/
#ifndef MXNET_ENGINE_H_
#define MXNET_ENGINE_H_
#if DMLC_USE_CXX11
#endif
namespace mxnet {
// forward declare engine
class Engine;
/*! \brief namespace of engine internal types. */
namespace engine {
/*! \brief Internal representation of variable. */
struct Var;
/*! \brief Internal representation of operator. */
struct Opr;
/*! \brief Variable pointer type, usually held by the user, used to specify dependencies. */
typedef Var* VarHandle;
/*! \brief Operator pointer type, usually held by the user. */
typedef Opr* OprHandle;
/*!
* \brief OnComplete Callback to the engine,
* called by AsyncFn when action completes
*/
class CallbackOnComplete {
public:
// use implicit copy and assign
  /*! \brief invoke the callback */
inline void operator()() const {
(*callback_)(engine_, param_);
}
private:
/*! \brief engine can see content of callback */
friend class ::mxnet::Engine;
/*! \brief the real callback */
void (*callback_)(Engine *, void *);
/*! \brief the engine class passed to callback */
Engine* engine_;
/*! \brief the parameter set on callback */
void* param_;
};
} // namespace engine
#if DMLC_USE_CXX11
/*! \brief Function property, used to hint what action is pushed to engine. */
enum class FnProperty {
/*! \brief Normal operation */
kNormal,
/*! \brief Copy operation from GPU to other devices */
kCopyFromGPU,
/*! \brief Copy operation from CPU to other devices */
kCopyToGPU,
/*! \brief Prioritized sync operation on CPU */
kCPUPrioritized,
/*! \brief Asynchronous function call */
kAsync
}; // enum class FnProperty
/*!
* \brief Dependency engine that schedules operations.
*/
class MXNET_API Engine {
public:
/*! \brief callback on complete*/
typedef engine::CallbackOnComplete CallbackOnComplete;
/*! \brief Synchronous operation to pass to engine. */
typedef std::function<void(RunContext)> SyncFn;
/*! \brief Asynchronous operation to pass to engine. */
typedef std::function<void(RunContext, CallbackOnComplete)> AsyncFn;
/*! \brief Variable pointer */
typedef engine::VarHandle VarHandle;
/*! \brief Operator pointer */
typedef engine::OprHandle OprHandle;
/*!
   * \brief Notify the engine about a shutdown.
   *  This can help the engine print fewer messages to the display.
   *
   *  Users do not have to call this function.
*/
virtual void NotifyShutdown() = 0;
/*!
* \brief Allocate a new variable, the variable can then
* be used to schedule the operation concurrently via dependency
* patterns.
* \return The new variable allocated.
*/
virtual VarHandle NewVariable() = 0;
/*!
* \brief Create a new operator. The returned operator could be saved
   * externally so that it could be reused for scheduling.
* \param fn The execution function.
* \param const_vars The variables that current operation will use but not
* mutate.
* \param mutable_vars The variables that current operation will mutate.
* \param prop Property of the function.
* \param opr_name The operator name.
* \return The new operator allocated.
*/
virtual OprHandle NewOperator(AsyncFn fn,
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal,
const char* opr_name = nullptr) = 0;
/*!
* \brief Delete the given operator.
* \param op The operator to delete.
*
* The delete will not happen immediately, but will wait until all the
* operations using this operator are completed.
*/
virtual void DeleteOperator(OprHandle op) = 0;
/*!
* \brief Push an operator to the engine.
* \param op The operator to push.
* \param exec_ctx Execution context.
* \param priority Priority of the action, as hint to the engine.
   * \param profiling The variable indicates whether to profile this operator.
*/
virtual void Push(OprHandle op, Context exec_ctx, int priority = 0, bool profiling = false) = 0;
/*!
* \brief Push an asynchronous operation to the engine.
* \param exec_fun Execution function, this function takes a parameter
* on_complete that must be called when the execution
* completes.
* \param exec_ctx Execution context.
* \param const_vars The variables that current operation will use but not
* mutate.
* \param mutable_vars The variables that current operation will mutate.
* \param prop Property of the function.
* \param priority Priority of the action, as hint to the engine.
* \param opr_name The operator name.
*/
virtual void PushAsync(AsyncFn exec_fun, Context exec_ctx,
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal,
int priority = 0,
const char* opr_name = nullptr) = 0;
/*!
* \brief Schedule the deletion of a variable.
*
* The delete will not happen immediately, but will wait until all the
* operations depending on var are completed.
*
* \param delete_fn A function that will be called after the variable is
* deleted.
* \param exec_ctx Execution context.
* \param var The variable to be deleted.
*/
virtual void DeleteVariable(SyncFn delete_fn,
Context exec_ctx,
VarHandle var) = 0;
/*!
* \brief Wait for a variable.
* \param var The variable we should wait for. This function returns when the
* variable is ready.
*/
virtual void WaitForVar(VarHandle var) = 0;
/*!
* \brief Wait until all the activity of engine finishes.
*/
virtual void WaitForAll() = 0;
/*!\brief virtual destructor */
virtual ~Engine() noexcept(false) {}
/*!
* \return Engine singleton.
*/
static Engine* Get();
/*!
* \brief Get shared pointer reference to engine singleton.
   * Most users should not call this function.
   * This function is called by another singleton X which requires
   * the engine to be destructed after X.
*
* \return A shared pointer to Engine singleton.
*/
static std::shared_ptr<Engine> _GetSharedRef();
/*!
   * \brief Push a synchronous operation to the engine.
* \param exec_fn Execution function that executes the operation.
* \param exec_ctx Execution context.
* \param const_vars The variables that current operation will use but not
* mutate.
* \param mutable_vars The variables that current operation will mutate.
* \param prop Property of the function.
* \param priority Priority of the action, as hint to the engine.
* \param opr_name The operator name.
* \tparam SyncFn the synchronous function to be pushed.
*/
inline void PushSync(SyncFn exec_fn, Context exec_ctx,
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal,
int priority = 0,
const char* opr_name = nullptr) {
this->PushAsync([exec_fn](RunContext ctx, CallbackOnComplete on_complete) {
exec_fn(ctx);
on_complete();
}, exec_ctx, const_vars, mutable_vars, prop, priority, opr_name);
}
/*!
* \brief factory function to create OnComplete callback.
   * \param callback the static callback function.
   * \param param the parameter passed to the callback.
*/
inline CallbackOnComplete CreateCallback(
void (*callback)(Engine *, void *), void *param) {
CallbackOnComplete ret;
ret.callback_ = callback;
ret.engine_ = this;
ret.param_ = param;
return ret;
}
}; // class Engine
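/*
 * Minimal usage sketch (illustration only, not part of the original header):
 * push a synchronous CPU task that mutates one variable, then wait for it.
 * Context::CPU() and the empty lambda body are assumptions for the example.
 *
 *   Engine* eng = Engine::Get();
 *   Engine::VarHandle var = eng->NewVariable();
 *   eng->PushSync([](RunContext rctx) { // do CPU work here
 *                 }, Context::CPU(), {}, {var});
 *   eng->WaitForVar(var);
 */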
#endif // DMLC_USE_CXX11
} // namespace mxnet
#endif // MXNET_ENGINE_H_
//===== EXPANDED: ../include/mxnet/engine.h =====
namespace mxnet {
/*!
* \brief The resources that can be requested by Operator
*/
struct ResourceRequest {
/*! \brief Resource type, indicating what the pointer type is */
enum Type {
/*! \brief mshadow::Random<xpu> object */
kRandom,
/*! \brief A dynamic temp space that can be arbitrary size */
kTempSpace
};
/*! \brief type of resources */
Type type;
/*! \brief default constructor */
ResourceRequest() {}
/*!
* \brief constructor, allow implicit conversion
* \param type type of resources
*/
ResourceRequest(Type type) // NOLINT(*)
: type(type) {}
};
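// Example (relies on the implicit constructor above; the vector name is
// hypothetical): operators typically declare their needs as
//   std::vector<ResourceRequest> reqs = {ResourceRequest::kTempSpace};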
/*!
* \brief Resources used by mxnet operations.
* A resource is something special other than NDArray,
 *  but still participates in the engine's dependency tracking.
*/
struct Resource {
/*! \brief The original request */
ResourceRequest req;
/*! \brief engine variable */
engine::VarHandle var;
  /*! \brief identifier information, used for debugging purposes */
int32_t id;
/*!
* \brief pointer to the resource, do not use directly,
* access using member functions
*/
void *ptr_;
/*! \brief default constructor */
Resource() : id(0) {}
/*!
* \brief Get random number generator.
* \param stream The stream to use in the random number generator.
* \return the mshadow random number generator requested.
* \tparam xpu the device type of random number generator.
*/
template<typename xpu, typename DType>
inline mshadow::Random<xpu, DType>* get_random(
mshadow::Stream<xpu> *stream) const {
CHECK_EQ(req.type, ResourceRequest::kRandom);
mshadow::Random<xpu, DType> *ret =
static_cast<mshadow::Random<xpu, DType>*>(ptr_);
ret->set_stream(stream);
return ret;
}
/*!
* \brief Get space requested as mshadow Tensor.
* The caller can request arbitrary size.
*
* This space can be shared with other calls to this->get_space.
   * So the caller needs to serialize the calls when using the conflicting space.
   * The old space can get freed; however, this will incur a synchronization
   * when running on a device, so that the launched kernels that depend on the temp space
   * can finish correctly.
*
   * \param shape the Shape of the returned tensor.
   * \param stream the stream of the returned tensor.
   * \return the mshadow tensor requested.
   * \tparam xpu the device type of the requested tensor.
   * \tparam ndim the number of dimensions of the tensor requested.
*/
template<typename xpu, int ndim>
inline mshadow::Tensor<xpu, ndim, real_t> get_space(
mshadow::Shape<ndim> shape, mshadow::Stream<xpu> *stream) const {
return get_space_typed<xpu, ndim, real_t>(shape, stream);
}
/*!
* \brief Get cpu space requested as mshadow Tensor.
* The caller can request arbitrary size.
*
   * \param shape the Shape of the returned tensor.
   * \return the mshadow tensor requested.
   * \tparam ndim the number of dimensions of the tensor requested.
*/
template<int ndim>
inline mshadow::Tensor<cpu, ndim, real_t> get_host_space(
mshadow::Shape<ndim> shape) const {
return get_host_space_typed<cpu, ndim, real_t>(shape);
}
/*!
* \brief Get space requested as mshadow Tensor in specified type.
* The caller can request arbitrary size.
*
   * \param shape the Shape of the returned tensor.
   * \param stream the stream of the returned tensor.
   * \return the mshadow tensor requested.
   * \tparam xpu the device type of the requested tensor.
   * \tparam ndim the number of dimensions of the tensor requested.
*/
template<typename xpu, int ndim, typename DType>
inline mshadow::Tensor<xpu, ndim, DType> get_space_typed(
mshadow::Shape<ndim> shape, mshadow::Stream<xpu> *stream) const {
CHECK_EQ(req.type, ResourceRequest::kTempSpace);
return mshadow::Tensor<xpu, ndim, DType>(
reinterpret_cast<DType*>(get_space_internal(shape.Size() * sizeof(DType))),
shape, shape[ndim - 1], stream);
}
/*!
* \brief Get CPU space as mshadow Tensor in specified type.
* The caller can request arbitrary size.
*
   * \param shape the Shape of the returned tensor
   * \return the mshadow tensor requested
   * \tparam ndim the number of dimensions of the tensor requested
   * \tparam DType the requested data type
*/
template<int ndim, typename DType>
inline mshadow::Tensor<cpu, ndim, DType> get_host_space_typed(
mshadow::Shape<ndim> shape) const {
return mshadow::Tensor<cpu, ndim, DType>(
reinterpret_cast<DType*>(get_host_space_internal(shape.Size() * sizeof(DType))),
shape, shape[ndim - 1], NULL);
}
/*!
* \brief internal function to get space from resources.
* \param size The size of the space.
* \return The allocated space.
*/
void* get_space_internal(size_t size) const;
/*!
* \brief internal function to get cpu space from resources.
* \param size The size of space.
* \return The allocated space
*/
void *get_host_space_internal(size_t size) const;
};
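/*
 * Illustrative temp-space sketch (assumption: `res` is a Resource obtained
 * from a kTempSpace request and `s` is the current mshadow::Stream<cpu>*):
 *
 *   mshadow::Tensor<cpu, 2, float> workspace =
 *       res.get_space_typed<cpu, 2, float>(mshadow::Shape2(rows, cols), s);
 *   // `workspace` is only valid until the next get_space call on `res`.
 */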
/*! \brief Global resource manager */
class ResourceManager {
public:
/*!
* \brief Get resource of requested type.
* \param ctx the context of the request.
* \param req the resource request.
* \return the requested resource.
   * \note The returned resource's ownership is
   *       still held by the manager singleton.
*/
virtual Resource Request(Context ctx, const ResourceRequest &req) = 0;
/*!
* \brief Seed all the allocated random numbers.
* \param seed the seed to the random number generators on all devices.
*/
virtual void SeedRandom(uint32_t seed) = 0;
/*! \brief virtual destructor */
virtual ~ResourceManager() DMLC_THROW_EXCEPTION {}
/*!
* \return Resource manager singleton.
*/
static ResourceManager *Get();
};
} // namespace mxnet
#endif // MXNET_RESOURCE_H_
//===== EXPANDED: ../include/mxnet/resource.h =====
//===== EXPANDING: ../src/operator/mshadow_op.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file mshadow_op.h
* \brief
* \author Bing Xu
*/
#ifndef MXNET_OPERATOR_MSHADOW_OP_H_
#define MXNET_OPERATOR_MSHADOW_OP_H_
//===== EXPANDING: ../src/operator/special_functions-inl.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file special_functions-inl.h
* \brief
* \author Valentin Flunkert
*/
#ifndef MXNET_OPERATOR_SPECIAL_FUNCTIONS_INL_H_
#define MXNET_OPERATOR_SPECIAL_FUNCTIONS_INL_H_
namespace mxnet {
namespace op {
namespace special_functions {
template<typename DType>
struct helper_numeric_limits {
MSHADOW_XINLINE static DType max();
};
template<>
struct helper_numeric_limits<double> {
MSHADOW_XINLINE static double max() {
return DBL_MAX;
}
};
template<>
struct helper_numeric_limits<float> {
MSHADOW_XINLINE static double max() {
return FLT_MAX;
}
};
// This code is based on the Cephes Library available at http://www.netlib.org/cephes
// The original author, Stephen Moshier, has kindly given permission to use this code
// in mxnet. (See email below).
//
// Date: Tue, 13 Sep 2016 09:28:20 -0400
// From: Stephen Moshier
// To: Flunkert, Valentin
// Subject: Re: cephes code in mxnet
//
// Hello Valentin,
//
// Thank you for writing. You are welcome to use and modify the Cephes code
// and distribute it under the Apache license.
//
// Good luck with your project,
// Steve Moshier
//
// Cephes Math Library Release 2.2: June, 1992
// Copyright 1984, 1987, 1992 by Stephen L. Moshier
// Direct inquiries to 30 Frost Street, Cambridge, MA 02140
//
struct cephes {
/*
* Helper to evaluate a polynomial given an array of coefficients.
*/
template <typename DType>
MSHADOW_XINLINE static DType polevl(DType x, const DType coef[], int N) {
DType ans;
DType const *p;
int i;
p = coef;
ans = *p++;
i = N;
do {
ans = ans * x + *p++;
} while ( --i );
return( ans );
}
/*
* Helper function for psi that handles double/float specific differences
* in the algorithm.
*/
template<typename DType>
MSHADOW_XINLINE static DType psi_helper(DType s);
/*
*
* Psi (digamma) function
*
*
* SYNOPSIS:
*
* float x, y, psif();
*
* y = psif( x );
*
*
* DESCRIPTION:
*
* d -
* psi(x) = -- ln | (x)
* dx
*
* is the logarithmic derivative of the gamma function.
* For integer x,
* n-1
* -
* psi(n) = -EUL + > 1/k.
* -
* k=1
*
* This formula is used for 0 < n <= 10. If x is negative, it
* is transformed to a positive argument by the reflection
* formula psi(1-x) = psi(x) + pi cot(pi x).
* For general positive x, the argument is made greater than 10
* using the recurrence psi(x+1) = psi(x) + 1/x.
* Then the following asymptotic expansion is applied:
*
* inf. B
* - 2k
* psi(x) = log(x) - 1/2x - > -------
* - 2k
* k=1 2k x
*
* where the B2k are Bernoulli numbers.
*
* ACCURACY:
* Absolute error, relative when |psi| > 1 :
* arithmetic domain # trials peak rms
* IEEE -33,0 30000 8.2e-7 1.2e-7
* IEEE 0,33 100000 7.3e-7 7.7e-8
*
* ERROR MESSAGES:
* message condition value returned
* psi singularity x integer <=0 MAXNUMF
*/
template<typename DType>
MSHADOW_XINLINE static DType psi(DType x) {
DType p, q, nz, s, w, y;
int i, n, negative;
DType EUL(0.57721566490153286061);
DType PI(3.14159265358979323846);
negative = 0;
nz = 0.0;
if ( x <= 0.0 ) {
negative = 1;
q = x;
p = std::floor(q);
if ( p == q ) {
return helper_numeric_limits<double>::max();
}
/* Remove the zeros of tan(PI x)
* by subtracting the nearest integer from x
*/
nz = q - p;
if ( nz != 0.5 ) {
if ( nz > 0.5 ) {
p += 1.0;
nz = q - p;
}
nz = PI/std::tan(PI*nz);
} else {
nz = 0.0;
}
x = 1.0 - x;
}
/* check for positive integer up to 10 */
if ( (x <= 10.0) && (x == std::floor(x)) ) {
y = 0.0;
n = x;
for ( i = 1; i < n; i++ ) {
w = i;
y += 1.0/w;
}
y -= EUL;
goto done;
}
s = x;
w = 0.0;
while ( s < 10.0 ) {
w += 1.0/s;
s += 1.0;
}
y = psi_helper(s);
y = logf(s) - (0.5/s) - y - w;
done:
if ( negative ) {
y -= nz;
}
return(y);
}
};
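// Reference values for quick sanity checks (standard digamma identities,
// added for convenience): psi(1) = -EUL ~ -0.5772157, psi(2) = 1 - EUL ~ 0.4227843.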
template<>
MSHADOW_XINLINE double cephes::psi_helper<double>(double s) {
double z;
const double A[] = {
8.33333333333333333333E-2,
-2.10927960927960927961E-2,
7.57575757575757575758E-3,
-4.16666666666666666667E-3,
3.96825396825396825397E-3,
-8.33333333333333333333E-3,
8.33333333333333333333E-2
};
if ( s < 1.0e17 ) {
z = 1.0/(s * s);
return z * cephes::polevl<double>(z, A, 6);
} else {
return 0.0;
}
}
template<>
MSHADOW_XINLINE float cephes::psi_helper<float>(float s) {
float z;
const float A[] = {
-4.16666666666666666667E-3f,
3.96825396825396825397E-3f,
-8.33333333333333333333E-3f,
8.33333333333333333333E-2f
};
if ( s < 1.0e8 ) {
z = 1.0/(s * s);
return z * cephes::polevl<float>(z, A, 3);
} else {
return 0.0;
}
}
} // namespace special_functions
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_SPECIAL_FUNCTIONS_INL_H_
//===== EXPANDED: ../src/operator/special_functions-inl.h =====
namespace mxnet {
namespace op {
namespace mshadow_op {
#ifdef __CUDA_ARCH__
__constant__ const float PI = 3.14159265358979323846;
#else
const float PI = 3.14159265358979323846;
using std::isnan;
#endif
/*! \brief identity Operation */
struct identity {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(a);
}
};
struct identity_grad {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(DType(1.0f));
}
};
struct left {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return a;
}
};
struct right {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return b;
}
};
struct negation {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(-a);
}
};
/*! \brief sigmoid unit */
struct sigmoid {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(DType(1.0f) / (DType(1.0f) + expf(-a)));
}
};
struct sigmoid_grad {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(a * (DType(1.0f) - a));
}
};
/*! \brief Rectified Linear Operation */
struct relu {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(a > DType(0.0f) ? a : DType(0.0f));
}
};
struct relu_grad {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(a > DType(0.0f) ? DType(1.0f) : DType(0.0f));
}
};
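// Note (added for context): these Map functors are consumed element-wise by
// mshadow expression templates, e.g. a forward computation would typically be
//   out = mshadow::expr::F<mshadow_op::relu>(in);
// where `in` and `out` are mshadow tensors of matching shape (illustrative).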
/*! \brief Leaky ReLU Operation */
struct xelu {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return DType(a > DType(0.0f) ? a : a * b);
}
};
struct xelu_grad {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return DType(a > DType(0.0f) ? DType(1.0f) : b);
}
};
/*! \brief Exponential Linear Unit */
struct elu {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType x, DType a) {
return DType(x > DType(0.0f) ? x : a * (expf(x) - DType(1.0f)));
}
};
struct elu_grad {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType x, DType a) {
return DType(x > DType(0.0f) ? DType(1.0f) : a + x);
}
};
struct tanh {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(tanhf( a ));
}
};
struct tanh_grad {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(DType(1.0f) - a * a);
}
};
/*! \brief SoftReLU, also known as softplus activation. */
struct softrelu {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(log1pf(expf(a)));
}
};
struct softrelu_grad {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(DType(1.0f) - expf(-a));
}
};
struct exp {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(expf(a));
}
};
struct expm1 {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(expm1f(a));
}
};
struct log {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(logf(a));
}
};
struct log10 {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(log10f(a));
}
};
struct log2 {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(log2f(a));
}
};
struct log_grad {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(DType(1.0f) / a);
}
};
struct sin {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(sinf(a));
}
};
struct sin_grad {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(cosf(a));
}
};
struct log1p {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(log1pf(a));
}
};
struct log1p_grad {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(DType(1.0f) / (DType(1.0f) + a));
}
};
struct cos {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(cosf(a));
}
};
struct cos_grad {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(-sinf(a));
}
};
struct tan {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(tanf(a));
}
};
struct tan_grad {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(a * a + 1);
}
};
struct arcsin {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(asinf(a));
}
};
struct arcsin_grad {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(1.0 / (sqrtf(1 - a*a)));
}
};
struct arccos {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(acosf(a));
}
};
struct arccos_grad {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(-1.0 / (sqrtf(1 - a*a)));
}
};
struct arctan {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(atanf(a));
}
};
struct arctan_grad {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(1 / (a*a + 1));
}
};
struct hypot {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return DType(sqrtf(a * a + b * b));
}
};
struct hypot_grad_left {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return DType(a/sqrtf(a * a + b * b));
}
};
struct hypot_grad_right {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return DType(b/sqrtf(a * a + b * b));
}
};
struct degrees {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(180. / PI * a);
}
};
struct degrees_grad {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(180. / PI);
}
};
struct radians {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(PI /180. * a);
}
};
struct radians_grad {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(PI / 180.);
}
};
struct sinh {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(sinhf(a));
}
};
struct sinh_grad {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(coshf(a));
}
};
struct cosh {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(coshf(a));
}
};
struct cosh_grad {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(sinhf(a));
}
};
struct arcsinh {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(asinhf(a));
}
};
struct arcsinh_grad {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(1.0 / (sqrtf(1 + a*a)));
}
};
struct arccosh {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(acoshf(a));
}
};
struct arccosh_grad {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(1.0 / (sqrtf(a*a - 1.0)));
}
};
struct arctanh {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(atanhf(a));
}
};
struct arctanh_grad {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(-1.0 / (a*a - 1.0));
}
};
struct square {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(a * a);
}
};
struct square_grad {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(DType(2.0f) * a);
}
};
/*! \brief used to generate a Bernoulli mask */
struct threshold {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return DType(a < b ? DType(1.0f) : DType(0.0f));
}
};
/*! \brief used to generate the element-wise abs */
struct abs {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(fabsf(float(a))); // NOLINT(*)
}
};
/*! \brief used to generate the element-wise sign */
struct sign {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
if (a < 0.0f) return DType(-DType(1.0f));
if (a > 0.0f) return DType(DType(1.0f));
return DType(DType(0.0f));
}
};
struct sign_grad {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(DType(0.0f));
}
};
/*! \brief used to generate the element-wise power */
struct power {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return DType(powf( a, b ));
}
};
struct power_grad {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return DType(powf( a, b - 1 )*b);
}
};
struct power_rgrad {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return DType(powf( a, b )*logf(a));
}
};
struct rpower {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return DType(powf( b, a ));
}
};
struct rpower_grad {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return DType(a*logf(b));
}
};
/*! \brief used to generate the element-wise maximum */
struct maximum {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return a > b ? a : b;
}
};
/*! \brief used to generate the element-wise minimum */
struct minimum {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return a < b ? a : b;
}
};
struct ge {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return a >= b ? DType(1) : DType(0);
}
};
struct gt {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return a > b ? DType(1) : DType(0);
}
};
struct lt {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return a < b ? DType(1) : DType(0);
}
};
struct le {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return a <= b ? DType(1) : DType(0);
}
};
struct eq {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return a == b ? DType(1) : DType(0);
}
};
struct ne {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return a != b ? DType(1) : DType(0);
}
};
/*! \brief used to generate the element-wise sqrt */
struct square_root {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(sqrtf(a));
}
};
struct square_root_grad {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(DType(0.5f) / a);
}
};
/*! \brief used to generate the element-wise reciprocal sqrt */
struct reciprocal_square_root {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(DType(1.0f)/sqrtf(a));
}
};
struct reciprocal_square_root_grad {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(-(DType(1.0f) / (DType(2.0f) * a * sqrtf(a))));
}
};
/*! \brief used to generate the element-wise round */
struct round {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(roundf(a));
}
};
/*! \brief used to generate the element-wise ceil */
struct ceil {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(ceilf(a));
}
};
/*! \brief used to generate the element-wise floor */
struct floor {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
return DType(floorf(a));
}
};
/*! \brief used to round a number to the nearest integer */
struct rint {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
float floor = floorf(a);
float ceil = ceilf(a);
    return DType((a - floor) <= (ceil - a) ? floor : ceil);  // pick the closer integer
}
};
/*! \brief used to round a number towards zero (to the integer nearest to 0) */
struct fix {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
float floor = floorf(a);
float ceil = ceilf(a);
    return DType(a < 0.0f ? ceil : floor);  // truncate towards zero
}
};
/*! \brief used to generate the gradient of MAE loss */
struct minus_sign {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return DType(a-b > DType(0.0f) ? DType(1.0f) : -DType(1.0f));
}
};
struct rminus {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return DType(b-a);
}
};
struct div_grad {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return DType(DType(1)/b);
}
};
struct div_rgrad {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return DType(-a/(b*b));
}
};
struct rdiv {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return DType(b/a);
}
};
struct rdiv_grad {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return DType(-b/(a*a));
}
};
struct clip {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType x, DType bound) {
if (x > bound) {
return bound;
} else if (x < -bound) {
return -bound;
} else {
return x;
}
}
};
/***** gamma ******/
struct gamma {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
// default implementation using floating precision
return DType(tgammaf(a));
}
};
template<>
MSHADOW_XINLINE double gamma::Map<double>(double a) {
return tgamma(a);
}
struct gamma_grad {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
// default implementation using floating precision
return DType(tgammaf(a) * special_functions::cephes::psi<float>(a));
}
};
template<>
MSHADOW_XINLINE double gamma_grad::Map<double>(double a) {
return tgamma(a) * special_functions::cephes::psi<double>(a);
}
/***** gammaln ******/
struct gammaln {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
// default implementation using floating precision
return DType(lgammaf(a));
}
};
template<>
MSHADOW_XINLINE double gammaln::Map<double>(double a) {
return lgamma(a);
}
struct gammaln_grad {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
// default implementation using floating precision
return DType(special_functions::cephes::psi<float>(a));
}
};
template<>
MSHADOW_XINLINE double gammaln_grad::Map<double>(double a) {
return special_functions::cephes::psi<double>(a);
}
/* Smooth L1 Loss is a loss specific for R-CNN franchise training
* Smooth L1 Loss function
 * f(x) = 0.5 * (sigma * x) ^ 2,   if |x| < 1 / sigma^2
 *      = |x| - 0.5 / sigma^2,     otherwise
* When sigma = 1, it is equivalent to Huber Loss evaluated at
* delta = 1.
* smooth_l1_loss = w_out * f(w_in * x)
* with w_in, w_out provided by input_data.
*/
struct smooth_l1_loss {
// a is x, b is sigma2
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
b *= b;
if (a > 1.0f / b) {
return a - 0.5f / b;
} else if (a < -1.0f / b) {
return -a - 0.5f / b;
} else {
return 0.5f * a * a * b;
}
}
}; // struct smooth_l1_loss
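// Worked example (sigma = 1, so b becomes sigma^2 = 1 inside Map):
//   smooth_l1_loss::Map(0.5f, 1.0f) = 0.5 * 0.25 * 1 = 0.125   (|x| < 1)
//   smooth_l1_loss::Map(2.0f, 1.0f) = 2 - 0.5        = 1.5     (|x| >= 1)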
/* The derivative of smooth l1 loss is
 * f'(x) = sigma^2 * x, if |x| < 1 / sigma^2
 *       = sign(x),     otherwise
*/
struct smooth_l1_gradient {
// a is x, b is sigma2
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
b *= b;
if (a > 1.0f / b) {
return 1.0f;
} else if (a < -1.0f / b) {
return DType(-1);
} else {
return b * a;
}
}
};  // struct smooth_l1_gradient
/*! \brief product reducer */
struct product {
/*! \brief do reduction into dst */
template<typename DType>
MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*)
dst *= src;
}
/*!
*\brief calculate gradient of redres with respect to redsrc,
   * redres: reduced result, redsrc: one of the reduction elements
*/
template<typename DType>
MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) {
return redres / redsrc;
}
/*!
*\brief set the initial value during reduction
*/
template<typename DType>
MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*)
initv = 1;
}
};
namespace isnan_typed {
template<typename DType>
MSHADOW_XINLINE bool IsNan(volatile DType val) {
return false;
}
template<>
MSHADOW_XINLINE bool IsNan(volatile float val) {
return isnan(val);
}
template<>
MSHADOW_XINLINE bool IsNan(volatile double val) {
return isnan(val);
}
template<>
MSHADOW_XINLINE bool IsNan(volatile long double val) {
return isnan(val);
}
template<>
MSHADOW_XINLINE bool IsNan(volatile mshadow::half::half_t val) {
return (val.half_ & 0x7fff) > 0x7c00;
}
}; // namespace isnan_typed
/*! \brief sum reducer that ignores NaN values in the input */
struct nansum {
/*! \brief do reduction into dst */
template<typename DType>
MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*)
if (isnan_typed::IsNan(dst)) {
if (isnan_typed::IsNan(src)) {
dst = DType(0);
} else {
dst = src;
}
} else {
if (isnan_typed::IsNan(src)) {
dst = dst;
} else {
dst += src;
}
}
}
/*!
*\brief set the initial value during reduction
*/
template<typename DType>
MSHADOW_XINLINE static void SetInitValue(DType & initv) { // NOLINT(*)
initv = 0;
}
};
struct nansum_grad {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return isnan_typed::IsNan(a) ? DType(0) : DType(1);
}
};
/*! \brief product reducer that ignores NaN values in the input */
struct nanprod {
/*! \brief do reduction into dst */
template<typename DType>
MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*)
if (isnan_typed::IsNan(dst)) {
if (isnan_typed::IsNan(src)) {
dst = DType(1);
} else {
dst = src;
}
} else {
if (isnan_typed::IsNan(src)) {
dst = dst;
} else {
dst *= src;
}
}
}
/*!
*\brief set the initial value during reduction
*/
template<typename DType>
MSHADOW_XINLINE static void SetInitValue(DType & initv) { // NOLINT(*)
initv = 1;
}
};
struct nanprod_grad {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return isnan_typed::IsNan(a) ? DType(0) : b / a;
}
};
} // namespace mshadow_op
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_MSHADOW_OP_H_
//===== EXPANDED: ../src/operator/mshadow_op.h =====
namespace mxnet {
/*! \brief namespace to support all possible NDArray operators */
namespace ndarray {
struct BinaryBase {
inline static TShape GetShape(const TShape &lshape, const TShape &rshape) {
CHECK(lshape == rshape) << "operands shape mismatch";
CHECK(lshape.ndim() != 0) << "source operand have zero dimension shape";
return lshape;
}
};
// operators
struct Plus : public BinaryBase {
typedef mshadow::op::plus mshadow_op;
};
struct Minus : public BinaryBase {
typedef mshadow::op::minus mshadow_op;
};
struct Mul : public BinaryBase {
typedef mshadow::op::mul mshadow_op;
};
struct Div : public BinaryBase {
typedef mshadow::op::div mshadow_op;
};
struct ClipMin : public BinaryBase {
struct mshadow_op {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
if (a < b) {
return b;
} else {
return a;
}
}
};
};
struct ClipMax : public BinaryBase {
struct mshadow_op {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
if (a > b) {
return b;
} else {
return a;
}
}
};
};
struct OneHotEncode {
inline static TShape GetShape(const TShape &index, const TShape &proptype) {
    CHECK(index.ndim() == 1 && proptype.ndim() == 2) << "OneHotEncode only supports 1d index.";
CHECK_EQ(index[0], proptype[0]) << "OneHotEncode shape inconsistent";
return proptype;
}
};
struct MatChooseRowElem {
inline static TShape GetShape(const TShape &lshape, const TShape &rshape) {
CHECK(lshape.ndim() == 2 && rshape.ndim() == 1)
<< "choose_row_element only support 2D Matrix and 1D index";
CHECK_EQ(lshape[0], rshape[0]) << "choose_row_element index and matrix shape mismatch";
return rshape;
}
};
struct MatFillRowElem {
inline static TShape GetShape(const TShape &lshape, const TShape &mshape, const TShape &rshape) {
    CHECK(lshape.ndim() == 2 && mshape.ndim() == 1 && rshape.ndim() == 1)
        << "fill_row_element only supports 2D Matrix, 1D value and 1D index";
    CHECK((lshape[0] == mshape[0]) && (mshape[0] == rshape[0]))
        << "fill_row_element index vector, value vector and matrix shape mismatch";
return lshape;
}
};
// type holder for random number generators
struct UniformDistribution {};
struct GaussianDistribution {};
template<typename Device>
void EvalClip(const TBlob &src, const real_t &a_min, const real_t &a_max,
TBlob *ret, RunContext ctx);
template<typename Device, typename OP>
void Eval(const TBlob &lhs, const TBlob &mhs, const TBlob &rhs, TBlob *ret, RunContext ctx);
template<typename Device, typename OP>
void Eval(const TBlob &lhs, const TBlob &rhs, TBlob *ret, RunContext ctx);
template<typename Device, typename OP>
void Eval(const TBlob &src, TBlob *ret, RunContext ctx);
template<typename Device, typename OP, bool reverse>
void Eval(const TBlob &lhs, const real_t &rhs, TBlob *ret, RunContext ctx);
template<typename Device>
void Eval(const real_t &rhs, TBlob *ret, RunContext ctx);
template<typename Device, typename Distribution>
void EvalRandom(const real_t &a,
const real_t &b,
const Resource &resource,
TBlob *ret, RunContext ctx);
// copy function when only cpu is involved
template<typename DeviceFrom, typename DeviceTo>
void Copy(const TBlob &from, TBlob *to,
Context from_ctx, Context to_ctx,
RunContext ctx);
template<typename Device>
void ElementwiseSum(const std::vector<TBlob> source,
TBlob *out,
RunContext ctx);
// broadcasting
template <typename Device>
void EvalBroadcast(TBlob const& src, TBlob* ret, int size, RunContext ctx);
} // namespace ndarray
} // namespace mxnet
#endif // MXNET_NDARRAY_NDARRAY_FUNCTION_H_
//===== EXPANDED: ../src/ndarray/ndarray_function.h =====
//===== EXPANDING: ../src/ndarray/ndarray_function-inl.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file ndarray_function-inl.h
* \brief The real implementation of NDArray functions.
*/
#ifndef MXNET_NDARRAY_NDARRAY_FUNCTION_INL_H_
#define MXNET_NDARRAY_NDARRAY_FUNCTION_INL_H_
// this file will be included twice by CPU and GPU
// macro to help specialize evaluation function
#ifndef DECL_TERNARY
#define DECL_TERNARY(XPU, OP, FUN) \
template<> \
void Eval<XPU, OP>(const TBlob &lhs, const TBlob &mhs, \
const TBlob &rhs, TBlob *ret, RunContext ctx) { \
FUN<XPU, OP>(lhs, mhs, rhs, ret, ctx); \
}
#endif
#ifndef DECL_BINARY
#define DECL_BINARY(XPU, OP, FUN) \
template<> \
void Eval<XPU, OP>(const TBlob &lhs, const TBlob &rhs, TBlob *ret, RunContext ctx) { \
FUN<XPU, OP>(lhs, rhs, ret, ctx); \
}
#endif
#ifndef DECL_SCALAR
#define DECL_SCALAR(XPU, OP, FUN, REVERSE) \
template<> \
void Eval<XPU, OP, REVERSE>(const TBlob &lhs, const real_t &rhs, TBlob *ret, RunContext ctx) { \
FUN<XPU, OP, REVERSE>(lhs, rhs, ret, ctx); \
}
#endif
#if defined(__CUDACC__)
#define DEVICE gpu
#else
#define DEVICE cpu
#endif
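// Editor's note: when this translation unit is compiled without __CUDACC__, DEVICE
// is cpu, so for example DECL_BINARY(DEVICE, Plus, EvalBinary_) below expands to
//   template<>
//   void Eval<cpu, Plus>(const TBlob &lhs, const TBlob &rhs, TBlob *ret, RunContext ctx) {
//     EvalBinary_<cpu, Plus>(lhs, rhs, ret, ctx);
//   }
// i.e. each DECL_* line stamps out one explicit specialization of the generic Eval
// declared in ndarray_function.h, forwarding to the typed helper defined here.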
namespace mxnet {
namespace ndarray {
// true implementation
template<typename xpu, typename OP>
inline void EvalBinary_(const TBlob &lhs, const TBlob &rhs,
TBlob *ret, RunContext ctx) {
using namespace mshadow::expr;
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
CHECK_EQ(ret->type_flag_, lhs.type_flag_)
<< "Only support input/output with the same data type";
CHECK_EQ(ret->type_flag_, rhs.type_flag_)
<< "Only support input/output with the same data type";
MSHADOW_TYPE_SWITCH(ret->type_flag_, DType, {
ret->FlatTo2D<xpu, DType>(s)
= F<typename OP::mshadow_op>(lhs.FlatTo2D<xpu, DType>(s),
rhs.FlatTo2D<xpu, DType>(s));
});
}
template<typename xpu, typename OP>
inline void EvalOneHot_(const TBlob &index, const TBlob &rhs,
TBlob *ret, RunContext ctx) {
using namespace mshadow::expr;
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
// TODO(eric): support mixed type encoding, i.e. int index and float rhs.
  CHECK_EQ(ret->type_flag_, mshadow::default_type_flag)
    << "one_hot_encode only supports float32 as input/output";
  CHECK_EQ(rhs.type_flag_, mshadow::default_type_flag)
    << "one_hot_encode only supports float32 as input/output";
  CHECK_EQ(index.type_flag_, mshadow::default_type_flag)
    << "one_hot_encode only supports float32 as input/output";
ret->get<xpu, 2, real_t>(s) =
one_hot_encode(index.get<xpu, 1, real_t>(s),
rhs.shape_[1]);
}
template<typename xpu, typename OP>
inline void EvalMatChooseRowElem_(const TBlob &lhs, const TBlob &rhs,
TBlob *ret, RunContext ctx) {
using namespace mshadow::expr;
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
// TODO(eric): support mixed type choose, i.e. int index and float rhs.
  CHECK_EQ(ret->type_flag_, mshadow::default_type_flag)
    << "mat_choose_row_element only supports float32 as input/output";
  CHECK_EQ(rhs.type_flag_, mshadow::default_type_flag)
    << "mat_choose_row_element only supports float32 as input/output";
  CHECK_EQ(lhs.type_flag_, mshadow::default_type_flag)
    << "mat_choose_row_element only supports float32 as input/output";
ret->get<xpu, 1, real_t>(s)
= mat_choose_row_element(lhs.get<xpu, 2, real_t>(s),
rhs.get<xpu, 1, real_t>(s));
}
template<typename xpu, typename OP>
inline void EvalMatFillRowElem_(const TBlob &lhs, const TBlob &mhs, const TBlob &rhs,
TBlob *ret, RunContext ctx) {
using namespace mshadow::expr;
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
ret->get<xpu, 2, real_t>(s)
= mat_fill_row_element(lhs.get<xpu, 2, real_t>(s),
mhs.get<xpu, 1, real_t>(s),
rhs.get<xpu, 1, real_t>(s));
}
template<typename xpu, typename OP, bool reverse>
inline void EvalScalar_(const TBlob &lhs, const real_t &rhs,
TBlob *ret, RunContext ctx) {
using namespace mshadow::expr;
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
CHECK_EQ(ret->type_flag_, lhs.type_flag_)
<< "Only support input/output with the same data type";
if (reverse) {
MSHADOW_TYPE_SWITCH(ret->type_flag_, DType, {
ret->FlatTo2D<xpu, DType>(s)
= F<typename OP::mshadow_op>(scalar(DType(rhs)), lhs.FlatTo2D<xpu, DType>(s));
});
} else {
MSHADOW_TYPE_SWITCH(ret->type_flag_, DType, {
ret->FlatTo2D<xpu, DType>(s)
= F<typename OP::mshadow_op>(lhs.FlatTo2D<xpu, DType>(s), scalar(DType(rhs)));
});
}
}
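// Editor's sketch of the `reverse` flag: for a non-commutative OP such as Minus,
// reverse = false evaluates lhs OP rhs (tensor op scalar) while reverse = true
// evaluates rhs OP lhs (scalar op tensor), e.g. conceptually
//   EvalScalar_<cpu, Minus, false>(x, 2.0f, &out, ctx);  // out = x - 2
//   EvalScalar_<cpu, Minus, true>(x, 2.0f, &out, ctx);   // out = 2 - x
// which is what the scalar-first F<OP>(scalar(rhs), lhs) branch above implements.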
template<>
void EvalClip<DEVICE>(const TBlob &src, const real_t &a_min, const real_t &a_max,
TBlob *ret, RunContext ctx) {
typedef DEVICE xpu;
using namespace mshadow::expr;
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
CHECK_EQ(ret->type_flag_, src.type_flag_)
<< "Only support input/output with the same data type";
MSHADOW_TYPE_SWITCH(ret->type_flag_, DType, {
ret->FlatTo2D<xpu, DType>(s)
= F<ClipMax::mshadow_op>(
F<ClipMin::mshadow_op>(src.FlatTo2D<xpu, DType>(s), scalar(DType(a_min))),
scalar(DType(a_max)));
});
}
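// Editor's note: clipping is expressed as a composition of the two binary ops
// declared earlier, min(max(x, a_min), a_max) via ClipMax(ClipMin(x, a_min), a_max);
// with a_min = 0 and a_max = 1, an element -0.5 maps to 0, 0.3 stays 0.3, and 2.0
// maps to 1.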
template<>
void EvalRandom<DEVICE, UniformDistribution>(
const real_t &a,
const real_t &b,
const Resource &resource,
TBlob *ret,
RunContext ctx) {
typedef DEVICE xpu;
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
switch (ret->type_flag_) {
case mshadow::kFloat32:
{
mshadow::Random<xpu, float> *prnd = resource.get_random<xpu, float>(s);
mshadow::Tensor<xpu, 2, float> tmp = ret->FlatTo2D<xpu, float>(s);
prnd->SampleUniform(&tmp, float(a), float(b)); // NOLINT(*)
break;
}
case mshadow::kFloat64:
{
mshadow::Random<xpu, double> *prnd = resource.get_random<xpu, double>(s);
mshadow::Tensor<xpu, 2, double> tmp = ret->FlatTo2D<xpu, double>(s);
prnd->SampleUniform(&tmp, double(a), double(b)); // NOLINT(*)
break;
}
default:
LOG(FATAL) << "Random only support float32 and float64";
}
}
template<>
void EvalRandom<DEVICE, GaussianDistribution>(
const real_t &mu,
const real_t &sigma,
const Resource &resource,
TBlob *ret,
RunContext ctx) {
typedef DEVICE xpu;
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
switch (ret->type_flag_) {
case mshadow::kFloat32:
{
mshadow::Random<xpu, float> *prnd = resource.get_random<xpu, float>(s);
mshadow::Tensor<xpu, 2, float> tmp = ret->FlatTo2D<xpu, float>(s);
prnd->SampleGaussian(&tmp, float(mu), float(sigma)); // NOLINT(*)
break;
}
case mshadow::kFloat64:
{
mshadow::Random<xpu, double> *prnd = resource.get_random<xpu, double>(s);
mshadow::Tensor<xpu, 2, double> tmp = ret->FlatTo2D<xpu, double>(s);
prnd->SampleGaussian(&tmp, double(mu), double(sigma)); // NOLINT(*)
break;
}
default:
LOG(FATAL) << "Random only support float32 and float64";
}
}
template<>
void Eval<DEVICE>(const real_t &rhs, TBlob *ret, RunContext ctx) {
mshadow::Stream<DEVICE> *s = ctx.get_stream<DEVICE>();
MSHADOW_TYPE_SWITCH(ret->type_flag_, DType, {
ret->FlatTo2D<DEVICE, DType>(s) = DType(rhs);
});
}
template<>
void ElementwiseSum<DEVICE>(const std::vector<TBlob> source,
TBlob *dst,
RunContext ctx) {
typedef DEVICE xpu;
using namespace mshadow;
using namespace mshadow::expr;
Stream<xpu> *s = ctx.get_stream<xpu>();
for (size_t i = 1; i < source.size(); ++i) {
CHECK_EQ(source[i].type_flag_, dst->type_flag_)
<< "Only support input/output with the same data type";
}
MSHADOW_TYPE_SWITCH(dst->type_flag_, DType, {
Tensor<xpu, 2, DType> out = dst->FlatTo2D<xpu, DType>(s);
switch (source.size()) {
case 2: {
Tensor<xpu, 2, DType> in_0 = source[0].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> in_1 = source[1].FlatTo2D<xpu, DType>(s);
out = in_0 + in_1;
break;
}
case 3: {
Tensor<xpu, 2, DType> in_0 = source[0].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> in_1 = source[1].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> in_2 = source[2].FlatTo2D<xpu, DType>(s);
out = in_0 + in_1 + in_2;
break;
}
case 4: {
Tensor<xpu, 2, DType> in_0 = source[0].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> in_1 = source[1].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> in_2 = source[2].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> in_3 = source[3].FlatTo2D<xpu, DType>(s);
out = in_0 + in_1 + in_2 + in_3;
break;
}
default: {
Tensor<xpu, 2, DType> in_0 = source[0].FlatTo2D<xpu, DType>(s);
out = F<mshadow::op::identity>(in_0);
for (size_t i = 1; i < source.size(); ++i) {
out += source[i].FlatTo2D<xpu, DType>(s);
}
break;
}
}
});
}
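// Editor's note: the 2/3/4-input cases above are unrolled so the whole sum is one
// fused mshadow expression; the default branch copies the first input through
// F<identity> and then accumulates the remaining inputs with +=, one pass per
// extra input.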
template <>
void EvalBroadcast<DEVICE>(TBlob const& src, TBlob* ret, int size, RunContext ctx) {
typedef DEVICE xpu;
mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
mshadow::Tensor<xpu, 3> out = ret->get<xpu, 3, real_t>(s);
mshadow::Tensor<xpu, 2> in = src.get<xpu, 2, real_t>(s);
out = mshadow::expr::broadcast_with_axis(in, 0, size);
}
// declarations
DECL_BINARY(DEVICE, MatChooseRowElem, EvalMatChooseRowElem_)
DECL_TERNARY(DEVICE, MatFillRowElem, EvalMatFillRowElem_)
DECL_BINARY(DEVICE, OneHotEncode, EvalOneHot_)
DECL_BINARY(DEVICE, Plus, EvalBinary_)
DECL_BINARY(DEVICE, Minus, EvalBinary_)
DECL_BINARY(DEVICE, Mul, EvalBinary_)
DECL_BINARY(DEVICE, Div, EvalBinary_)
DECL_SCALAR(DEVICE, Plus, EvalScalar_, true)
DECL_SCALAR(DEVICE, Minus, EvalScalar_, true)
DECL_SCALAR(DEVICE, Mul, EvalScalar_, true)
DECL_SCALAR(DEVICE, Div, EvalScalar_, true)
// for reverse seq
DECL_SCALAR(DEVICE, Plus, EvalScalar_, false)
DECL_SCALAR(DEVICE, Minus, EvalScalar_, false)
DECL_SCALAR(DEVICE, Mul, EvalScalar_, false)
DECL_SCALAR(DEVICE, Div, EvalScalar_, false)
} // namespace ndarray
} // namespace mxnet
#endif // MXNET_NDARRAY_NDARRAY_FUNCTION_INL_H_
//===== EXPANDED: ../src/ndarray/ndarray_function-inl.h =====
namespace mxnet {
namespace ndarray {
template<>
void Copy<cpu, cpu>(const TBlob &from, TBlob *to,
Context from_ctx, Context to_ctx,
RunContext ctx) {
MSHADOW_TYPE_SWITCH(to->type_flag_, DType, {
if (to->type_flag_ == from.type_flag_) {
mshadow::Copy(to->FlatTo1D<cpu, DType>(),
from.FlatTo1D<cpu, DType>());
} else {
MSHADOW_TYPE_SWITCH(from.type_flag_, SrcDType, {
to->FlatTo1D<cpu, DType>() =
mshadow::expr::tcast<DType>(from.FlatTo1D<cpu, SrcDType>());
})
}
})
}
} // namespace ndarray
} // namespace mxnet
//===== EXPANDED: ../src/ndarray/ndarray_function.cc =====
//===== EXPANDING: ../src/ndarray/ndarray.cc =====
/*!
* Copyright (c) 2015 by Contributors
* \file ndarray.cc
 * \brief ndarray module of mxnet
*/
//===== EXPANDING: ../include/mxnet/ndarray.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file ndarray.h
 * \brief NDArray interface that handles array arithmetics.
*/
#ifndef MXNET_NDARRAY_H_
#define MXNET_NDARRAY_H_
//===== EXPANDING: ../include/mxnet/storage.h =====
/*!
* Copyright (c) 2015 by Contributors
* \file storage.h
* \brief Storage manager across multiple devices.
*/
#ifndef MXNET_STORAGE_H_
#define MXNET_STORAGE_H_
namespace mxnet {
/*!
* \brief Storage manager across multiple devices.
*/
class Storage {
public:
/*!
* \brief Storage handle.
*/
struct Handle {
/*!
* \brief Pointer to the data.
*/
void* dptr;
/*!
* \brief Size of the storage.
*/
size_t size;
/*!
* \brief Context information about device and ID.
*/
Context ctx;
};
/*!
* \brief Allocate a new contiguous memory for a given size.
* \param size Total size of memory in bytes.
* \param ctx Context information about the device and ID.
* \return Handle struct.
*/
virtual Handle Alloc(size_t size, Context ctx) = 0;
/*!
* \brief Free storage.
   * \param handle Handle struct.
*/
virtual void Free(Handle handle) = 0;
/*!
   * \brief Free storage directly, without putting it into the memory pool.
   *  This can cause synchronization with all previously run device functions.
   *
   *  This function is suitable for container structures whose storage needs to be
   *  upsized during the beginning phase of the iteration.
*
* \param handle Handle struct.
*/
virtual void DirectFree(Handle handle) = 0;
/*!
* \brief Destructor.
*/
virtual ~Storage() {}
/*!
* \return Storage singleton.
*/
static Storage* Get();
/*!
   * \brief Get shared pointer reference to the Storage singleton.
   *  Most users should not call this function.
   *  This function is called by another singleton X which requires
   *  Storage to be destructed after X.
*
* \return A shared pointer to Storage singleton.
*/
static std::shared_ptr<Storage> _GetSharedRef();
}; // class Storage
} // namespace mxnet
#endif // MXNET_STORAGE_H_
//===== EXPANDED: ../include/mxnet/storage.h =====
#if MKL_EXPERIMENTAL == 1
#endif
// check c++11
#if DMLC_USE_CXX11 == 0
#error "cxx11 was required for ndarray module"
#endif
namespace mxnet {
/*!
* \brief ndarray interface
*/
class NDArray {
public:
  /*! \brief default constructor */
NDArray() {
#if MKL_EXPERIMENTAL == 1
Mkl_mem_ = MKLMemHolder::create();
#endif
}
/*!
* \brief constructing a new dynamic NDArray
* \param shape the shape of array
* \param ctx context of NDArray
* \param delay_alloc whether delay the allocation
* \param dtype data type of this ndarray
*/
NDArray(const TShape &shape, Context ctx,
bool delay_alloc = false, int dtype = mshadow::default_type_flag)
: ptr_(std::make_shared<Chunk>(shape.Size(), ctx, delay_alloc, dtype)),
shape_(shape), offset_(0), dtype_(dtype) {
#if MKL_EXPERIMENTAL == 1
Mkl_mem_ = std::make_shared<MKLMemHolder>();
#endif
}
/*!
* \brief constructing a static NDArray that shares data with TBlob
   * Use with caution: allocate ONLY ONE NDArray for each TBlob, and make sure
   * the memory region stays available throughout the life of the NDArray
* \param data the memory content of static data
* \param dev_id the device id this tensor sits at
*/
NDArray(const TBlob &data, int dev_id)
: ptr_(std::make_shared<Chunk>(data, dev_id)), shape_(data.shape_), offset_(0),
dtype_(data.type_flag_) {
#if MKL_EXPERIMENTAL == 1
Mkl_mem_ = std::make_shared<MKLMemHolder>();
#endif
}
/*!
* \return the shape of current NDArray
*/
inline const TShape &shape() const {
return shape_;
}
/*!
* \return the data TBlob
*/
inline TBlob data() const {
TBlob res;
MSHADOW_TYPE_SWITCH(dtype_, DType, {
res = TBlob(static_cast<DType*>(ptr_->shandle.dptr)
+ offset_, shape_, ptr_->shandle.ctx.dev_mask());
});
#if MKL_EXPERIMENTAL == 1
res.Mkl_mem_ = Mkl_mem_;
#endif
return res;
}
/*!
* \return a chunk of raw data in TBlob
*/
inline TBlob raw_data(index_t offset, index_t length) const {
TBlob res;
TShape raw_shape(1);
raw_shape[0] = length;
MSHADOW_TYPE_SWITCH(dtype_, DType, {
res = TBlob(static_cast<DType*>(ptr_->shandle.dptr)
+ offset_ + offset, raw_shape, ptr_->shandle.ctx.dev_mask());
});
#if MKL_EXPERIMENTAL == 1
res.Mkl_mem_ = Mkl_mem_;
#endif
return res;
}
/*!
   * \return the context of the NDArray; this function is only valid when the NDArray is not empty
*/
inline Context ctx() const {
return ptr_->shandle.ctx;
}
/*!
   * \return the data type of the NDArray; this function is only valid when the NDArray is not empty
*/
inline int dtype() const {
return dtype_;
}
/*! \return whether this ndarray is not initialized */
inline bool is_none() const {
return ptr_.get() == nullptr;
}
/*!
* \brief Block until all the pending write operations with respect
* to current NDArray are finished, and read can be performed.
*/
inline void WaitToRead() const {
if (is_none()) return;
Engine::Get()->WaitForVar(ptr_->var);
}
/*!
* \brief Block until all the pending read/write operations with respect
* to current NDArray are finished, and write can be performed.
*/
inline void WaitToWrite() const {
if (is_none()) return;
/*!
* Push an empty mutable function to flush all preceding reads to the
* variable.
*/
Engine::Get()->PushSync([](RunContext) {}, Context{}, {}, {ptr_->var});
Engine::Get()->WaitForVar(ptr_->var);
}
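  // Editor's sketch of the intended usage (hypothetical `arr`/`buf`, assuming a CPU
  // NDArray whose buffer holds `nbytes` bytes; not part of the original header):
  //   arr.WaitToRead();                            // pending writes are finished
  //   std::memcpy(buf, arr.data().dptr_, nbytes);  // now safe to read
  // WaitToWrite additionally flushes pending reads, so call it before mutating the
  // underlying memory outside the engine; the SyncCopy* helpers declared below wrap
  // this pattern together with the actual copy.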
/*! \return the associated variable of the ndarray.*/
inline Engine::VarHandle var() const {
return ptr_->var;
}
/*!
* \brief save the content into binary stream
* \param strm the output stream
*/
void Save(dmlc::Stream *strm) const;
/*!
* \brief load the content from binary stream
   * \param strm the input stream
* \return whether the load is successful
*/
bool Load(dmlc::Stream *strm);
/*!
* \brief set all the elements in ndarray to be scalar
* \param scalar the scalar to set
* \return reference of self
*/
NDArray &operator=(real_t scalar);
/*!
* \brief elementwise add to current space
   * this mutates the current NDArray
* \param src the data to add
* \return reference of self
*/
NDArray &operator+=(const NDArray &src);
/*!
* \brief elementwise add to current space
   * this mutates the current NDArray
* \param src the data to add
* \return reference of self
*/
NDArray &operator+=(const real_t &src);
/*!
* \brief elementwise subtract from current ndarray
   * this mutates the current NDArray
* \param src the data to subtract
* \return reference of self
*/
NDArray &operator-=(const NDArray &src);
/*!
* \brief elementwise subtract from current ndarray
   * this mutates the current NDArray
* \param src the data to subtract
* \return reference of self
*/
NDArray &operator-=(const real_t &src);
/*!
   * \brief elementwise multiplication to current ndarray
   * this mutates the current NDArray
   * \param src the data to multiply by
* \return reference of self
*/
NDArray &operator*=(const NDArray &src);
/*!
   * \brief elementwise multiplication to current ndarray
   * this mutates the current NDArray
   * \param src the data to multiply by
* \return reference of self
*/
NDArray &operator*=(const real_t &src);
/*!
   * \brief elementwise division from current ndarray
   * this mutates the current NDArray
   * \param src the data to divide by
* \return reference of self
*/
NDArray &operator/=(const NDArray &src);
/*!
   * \brief elementwise division from current ndarray
   * this mutates the current NDArray
   * \param src the data to divide by
* \return reference of self
*/
NDArray &operator/=(const real_t &src);
/*!
* \brief return transpose of current NDArray
* \return a new transposed NDArray
*/
NDArray T() const;
/*!
   * \brief return a new copy of this NDArray
* \param ctx the new context of this NDArray
* \return the new copy
*/
NDArray Copy(Context ctx) const;
/*!
   * \brief Do a synchronous copy from a contiguous CPU memory region.
   *
   *  This function will call WaitToWrite before the copy is performed.
   *  This is useful for copying data from an existing memory region that is
   *  not wrapped by an NDArray (thus the dependency is not tracked).
   *
   * \param data the data source to copy from.
   * \param size the size of the source array, in units of sizeof(DType), not raw bytes.
*/
void SyncCopyFromCPU(const void *data, size_t size) const;
/*!
   * \brief Do a synchronous copy to a contiguous CPU memory region.
   *
   *  This function will call WaitToRead before the copy is performed.
   *  This is useful for copying data to an existing memory region that is
   *  not wrapped by an NDArray (thus the dependency is not tracked).
   *
   * \param data the destination buffer to copy into.
   * \param size the memory size we want to copy, in units of sizeof(DType), not raw bytes.
*/
void SyncCopyToCPU(void *data, size_t size) const;
/*!
* \brief Slice a NDArray
* \param begin begin index in first dim
* \param end end index in first dim
* \return sliced NDArray
*/
inline NDArray Slice(index_t begin, index_t end) const {
NDArray ret = *this;
CHECK(!is_none()) << "NDArray is not initialized";
CHECK_GE(shape_[0], end) << "Slice end index out of range";
size_t length = shape_.ProdShape(1, shape_.ndim());
ret.offset_ += begin * length;
ret.shape_[0] = end - begin;
return ret;
}
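  // Editor's example of the offset arithmetic above (hypothetical shapes): slicing
  // a (4, 3) array with begin = 1, end = 3 returns a view of shape (2, 3) that
  // shares storage with the original, starting 1 * 3 = 3 elements past its offset;
  // no data is copied.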
/*!
* \brief Index a NDArray
* \param idx the index
* \return idx-th sub array NDArray
*/
inline NDArray At(index_t idx) const {
NDArray ret = *this;
CHECK(!is_none()) << "NDArray is not initialized";
CHECK_GT(shape_[0], idx) << "index out of range";
size_t length = shape_.ProdShape(1, shape_.ndim());
ret.offset_ += idx * length;
if (shape_.ndim() > 1) {
ret.shape_ = TShape(shape_.data()+1, shape_.data()+shape_.ndim());
} else {
ret.shape_ = mshadow::Shape1(1);
}
return ret;
}
/*!
   * \brief Create an NDArray that shares memory with the current one.
   * The new array's memory size must not exceed that of the current array.
* \param shape new shape
* \param dtype The data type.
* \return NDArray in new shape and type.
*/
inline NDArray AsArray(const TShape &shape, int dtype) const {
CHECK_GE(shape_.Size() * mshadow::mshadow_sizeof(dtype_),
shape.Size() * mshadow::mshadow_sizeof(dtype))
<< "NDArray.AsArray: target memory size is bigger";
#if MKL_EXPERIMENTAL == 1
if (Mkl_mem_ != nullptr) {
// convert prv to cpu
Mkl_mem_->check_and_prv_to_cpu(ptr_->shandle.dptr);
}
#endif
NDArray ret = *this;
ret.shape_ = shape;
ret.dtype_ = dtype;
return ret;
}
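  // Editor's example (hypothetical shapes/dtypes): a float32 array of shape (2, 8)
  // occupies 64 bytes, so AsArray may reinterpret it as a float64 view of shape
  // (2, 4) or a uint8 view of shape (64,), but not as any shape/dtype combination
  // that would need more than 64 bytes.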
/*!
   * \brief Get a reshaped NDArray
* \param shape new shape
* \return NDArray in new shape
*/
inline NDArray Reshape(const TShape &shape) const {
CHECK_GE(shape_.Size(), shape.Size())
<< "NDArray.Reshape: target shape size is different from current shape";
NDArray ret = *this;
ret.shape_ = shape;
return ret;
}
/*!
   * \brief Allocate the space if its allocation was delayed.
   * This is an internal function used by the system; normal users should not call it.
*/
inline void CheckAndAlloc() const {
ptr_->CheckAndAlloc();
}
/*!
   * \brief Save a list of NDArrays into the Stream.
* \param fo The stream of output.
* \param data the NDArrays to be saved.
   * \param names the names of the NDArrays; optional, can be zero length.
*/
static void Save(dmlc::Stream* fo,
const std::vector<NDArray>& data,
const std::vector<std::string>& names);
/*!
* \brief Load list of narray into from th