Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Swig multi device #1574

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions contrib/swig/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,30 +21,30 @@ set(SCALA_VERSION "2.11.11" CACHE STRING "Scala version to compile the swig wrap
string(REGEX MATCH "[0-9]+\\.[0-9]+" SCALA_BIN_VERSION "${SCALA_VERSION}")

# Set java package (+cuda flag, if appropriate)
if(WITH_CUDA_BACKEND)
set(CMAKE_SWIG_FLAGS -package edu.cmu.dynet.internal -DSWIG_USE_CUDA)
else(WITH_CUDA_BACKEND)
set(CMAKE_SWIG_FLAGS -package edu.cmu.dynet.internal)
endif(WITH_CUDA_BACKEND)
set(CMAKE_SWIG_FLAGS -package edu.cmu.dynet.internal -Dfinal)

# Run swig
set_source_files_properties(dynet_swig.i PROPERTIES CPLUSPLUS ON)
swig_add_module(dynet_swig java dynet_swig.i)
if(${CMAKE_VERSION} VERSION_LESS "3.8.0")
swig_add_module(dynet_swig java dynet_swig.i)
else()
swig_add_library(dynet_swig LANGUAGE java SOURCES dynet_swig.i)
endif()

# add C++ compiler flags
if(WITH_CUDA_BACKEND)
set_target_properties(dynet_swig PROPERTIES
COMPILE_DEFINITIONS HAVE_CUDA)
endif(WITH_CUDA_BACKEND)
endif()

# Link with dynet library
if(WITH_CUDA_BACKEND)
MESSAGE("-- swig link with GPU library")
swig_link_libraries(dynet_swig dynet)
else(WITH_CUDA_BACKEND)
else()
MESSAGE("-- swig link with CPU library")
swig_link_libraries(dynet_swig dynet)
endif(WITH_CUDA_BACKEND)
endif()

# Create jar file
add_jar(
Expand All @@ -60,6 +60,7 @@ add_jar(
"${CMAKE_SWIG_OUTDIR}/CompactVanillaLSTMBuilder.java"
"${CMAKE_SWIG_OUTDIR}/CyclicalSGDTrainer.java"
"${CMAKE_SWIG_OUTDIR}/Device.java"
"${CMAKE_SWIG_OUTDIR}/DeviceManager.java"
"${CMAKE_SWIG_OUTDIR}/DeviceMempool.java"
"${CMAKE_SWIG_OUTDIR}/DeviceMempoolSizes.java"
"${CMAKE_SWIG_OUTDIR}/DeviceType.java"
Expand Down Expand Up @@ -112,6 +113,7 @@ add_jar(
"${CMAKE_SWIG_OUTDIR}/SWIGTYPE_p_p_p_char.java"
"${CMAKE_SWIG_OUTDIR}/SWIGTYPE_p_size_t.java"
"${CMAKE_SWIG_OUTDIR}/SWIGTYPE_p_std__vectorT_dynet__AlignedMemoryPool_p_t.java"
"${CMAKE_SWIG_OUTDIR}/SWIGTYPE_p_std__vectorT_dynet__Device_p_t.java"
"${CMAKE_SWIG_OUTDIR}/SWIGTYPE_p_std__vectorT_dynet__Node_p_t.java"
"${CMAKE_SWIG_OUTDIR}/SWIGTYPE_p_std__vectorT_dynet__Tensor_t.java"
"${CMAKE_SWIG_OUTDIR}/SWIGTYPE_p_std__vectorT_dynet__VariableIndex_t.java"
Expand Down
30 changes: 15 additions & 15 deletions contrib/swig/build.sbt
Original file line number Diff line number Diff line change
@@ -1,26 +1,27 @@
lazy val root = (project in file("."))
.settings(
name := "dynet_scala_helpers",
organization := "edu.cmu.dynet",
version := "0.0.1-SNAPSHOT"
)
.settings(
name := "dynet_scala_helpers",
organization := "edu.cmu.dynet",
version := "0.0.1-SNAPSHOT"
)

val DEFAULT_BUILD_PATH = "../../build/contrib/swig"

// The default scala version to use if none was specified from
// outside. When building with cmake, the scalaversion property
// should always be set; this is only a fallback for other cases.
val DEFAULT_SCALA_VERSION = "2.11.11"
val DEFAULT_SCALA_VERSION = "2.12.8"

scalaVersion := { sys.props.get("scalaversion") match {
scalaVersion := {
sys.props.get("scalaversion") match {
case Some(p) => p
case None => {
case None =>
println(s"using default scala version ${DEFAULT_SCALA_VERSION}")
DEFAULT_SCALA_VERSION
}
}}

}
}

javaOptions in Test ++= Seq("-Xms1G","-XX:+CMSClassUnloadingEnabled","-XX:+UseConcMarkSweepGC")

// This is where `make` does all its work, and it's where we'll do all our work as well.

Expand All @@ -29,10 +30,9 @@ lazy val buildPath = settingKey[String]("Build Path")
buildPath := {
val bp = sys.props.get("buildpath") match {
case Some(p) => p
case None => {
case None =>
println(s"using default buildpath ${DEFAULT_BUILD_PATH}")
DEFAULT_BUILD_PATH
}
}
if (new File(bp).exists) {
bp
Expand Down Expand Up @@ -93,6 +93,6 @@ assemblyMergeStrategy in assembly := {
// Don't include Scala libraries in the jar
// see https://github.com/sbt/sbt-assembly/issues/3
// and http://stackoverflow.com/questions/15856739/assembling-a-jar-containing-only-the-provided-dependencies
assembleArtifact in packageScala := false
assembleArtifact in assemblyPackageScala := false

libraryDependencies += "org.scalatest" %% "scalatest" % "3.0.0" % "test"
libraryDependencies += "org.scalatest" %% "scalatest" % "3.0.8" % "test"
92 changes: 71 additions & 21 deletions contrib/swig/dynet_swig.i
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ VECTORCONSTRUCTOR(std::vector<unsigned>, UnsignedVector, UnsignedVectorVector)
VECTORCONSTRUCTOR(std::vector<dynet::Expression>, ExpressionVector, ExpressionVectorVector)
VECTORCONSTRUCTOR(std::vector<dynet::Parameter>, ParameterVector, ParameterVectorVector)


// Useful SWIG libraries
%include "std_vector.i"
%include "std_string.i"
Expand All @@ -107,11 +106,17 @@ VECTORCONSTRUCTOR(std::vector<dynet::Parameter>, ParameterVector, ParameterVecto

%shared_ptr(dynet::ParameterStorage)
%shared_ptr(dynet::LookupParameterStorage)
%shared_ptr(dynet::ParameterStorageBase)

// Convert C++ exceptions into Java exceptions. This provides
// nice error messages for each listed exception, and a default
// "unknown error" message for all others.
%catches(std::invalid_argument, ...);
%catches(std::invalid_argument,
std::runtime_error,
std::domain_error,
dynet::out_of_memory,
dynet::cuda_exception,
...);

%pointer_functions(unsigned, uintp);
%pointer_functions(int, intp);
Expand Down Expand Up @@ -155,6 +160,8 @@ struct Node;
struct ParameterStorage;
struct LookupParameterStorage;

struct Device;

///////////////////////////////////
// declarations from dynet/dim.h //
///////////////////////////////////
Expand Down Expand Up @@ -322,9 +329,12 @@ private:

struct ParameterStorageBase {
virtual void scale_parameters(float a) = 0;
virtual void scale_gradient(float a) = 0;
virtual void zero() = 0;
virtual void squared_l2norm(float* sqnorm) const = 0;
virtual void g_squared_l2norm(float* sqnorm) const = 0;
virtual bool is_updated() const = 0;
virtual bool has_grad() const = 0;
virtual size_t size() const = 0;
virtual ~ParameterStorageBase();
};
Expand Down Expand Up @@ -385,11 +395,16 @@ class ParameterCollection {
float gradient_l2_norm() const;
void reset_gradient();

Parameter add_parameters(const Dim& d, float scale = 0.0f);
Parameter add_parameters(const Dim& d, const ParameterInit & init);
LookupParameter add_lookup_parameters(unsigned n, const Dim& d);
LookupParameter add_lookup_parameters(unsigned n, const Dim& d, const ParameterInit & init);

Parameter add_parameters(const Dim& d, float scale = 0.0f,
const std::string & name = "", Device *device = dynet::default_device);
Parameter add_parameters(const Dim& d, Device *device);
Parameter add_parameters(const Dim& d, const std::string & name, Device *device = dynet::default_device);
Parameter add_parameters(const Dim& d, const ParameterInit & init,
const std::string & name = "", Device *device = dynet::default_device);
LookupParameter add_lookup_parameters(unsigned n, const Dim& d,
const std::string & name = "", Device *device = dynet::default_device);
LookupParameter add_lookup_parameters(unsigned n, const Dim& d, const ParameterInit & init,
const std::string & name = "", Device *device = dynet::default_device);
void project_weights(float radius = 1.0f);
void set_weight_decay_lambda(float lambda);

Expand Down Expand Up @@ -434,6 +449,7 @@ struct Expression {
ComputationGraph *pg;
VariableIndex i;
Expression(ComputationGraph *pg, VariableIndex i) : pg(pg), i(i) { };
std::string get_device_name();
const Tensor& value();
const Dim& dim() const { return pg->get_dimension(i); }
};
Expand All @@ -448,10 +464,13 @@ Expression f(const T& xs, const T1& arg1);

/* INPUT OPERATIONS */

Expression input(ComputationGraph& g, real s);
Expression input(ComputationGraph& g, const real *ps);
Expression input(ComputationGraph& g, const Dim& d, const std::vector<float>* pdata);
Expression input(ComputationGraph& g, const Dim& d, const std::vector<unsigned int>& ids, const std::vector<float>& data, float defdata = 0.f);
Expression input(ComputationGraph& g, real s, Device *device = dynet::default_device);
Expression input(ComputationGraph& g, const real *ps, Device *device = dynet::default_device);
Expression input(ComputationGraph& g, const Dim& d, const std::vector<float>& data, Device *device = dynet::default_device);
// Expression input(ComputationGraph& g, const Dim& d, const std::vector<float>* pdata, Device *device = dynet::default_device);
Expression input(ComputationGraph& g, const Dim& d, const std::vector<unsigned int>& ids, const std::vector<float>& data, float defdata = 0.f, Device *device = dynet::default_device);
Expression one_hot(ComputationGraph& g, unsigned int d, unsigned int idx, Device *device = dynet::default_device);
Expression one_hot(ComputationGraph& g, unsigned int d, const std::vector<unsigned int>& ids, Device *device = dynet::default_device);
Expression parameter(ComputationGraph& g, Parameter p);
Expression parameter(ComputationGraph& g, LookupParameter lp);
Expression const_parameter(ComputationGraph& g, Parameter p);
Expand All @@ -465,14 +484,14 @@ Expression lookup(ComputationGraph& g, LookupParameter p, const std::vector<unsi
Expression const_lookup(ComputationGraph& g, LookupParameter p, const std::vector<unsigned>& indices);
//Expression const_lookup(ComputationGraph& g, LookupParameter p, const std::vector<unsigned>* pindices);

Expression zeros(ComputationGraph& g, const Dim& d);
Expression zeroes(ComputationGraph& g, const Dim& d);
Expression ones(ComputationGraph& g, const Dim& d);
Expression constant(ComputationGraph& g, const Dim& d, float val);
Expression random_normal(ComputationGraph& g, const Dim& d);
Expression random_bernoulli(ComputationGraph& g, const Dim& d, real p, real scale = 1.0f);
Expression random_uniform(ComputationGraph& g, const Dim& d, real left, real right);
Expression random_gumbel(ComputationGraph& g, const Dim& d, real mu = 0.0, real beta = 1.0);
Expression zeros(ComputationGraph& g, const Dim& d, Device *device = dynet::default_device);
Expression zeroes(ComputationGraph& g, const Dim& d, Device *device = dynet::default_device);
Expression ones(ComputationGraph& g, const Dim& d, Device *device = dynet::default_device);
Expression constant(ComputationGraph& g, const Dim& d, float val, Device *device = dynet::default_device);
Expression random_normal(ComputationGraph& g, const Dim& d, float mean=0.f, float stddev=1.0, Device *device = dynet::default_device);
Expression random_bernoulli(ComputationGraph& g, const Dim& d, real p, real scale = 1.0f, Device *device = dynet::default_device);
Expression random_uniform(ComputationGraph& g, const Dim& d, real left, real right, Device *device = dynet::default_device);
Expression random_gumbel(ComputationGraph& g, const Dim& d, real mu = 0.0, real beta = 1.0, Device *device = dynet::default_device);

/* ARITHMETIC OPERATIONS */

Expand Down Expand Up @@ -677,6 +696,8 @@ Expression trace_of_product(const Expression& x, const Expression& y);
Expression layer_norm(const Expression& x, const Expression& g, const Expression& b);
Expression weight_norm(const Expression& w, const Expression& g);

Expression to_device(const Expression & x, Device *device);

/////////////////////////////////////
// declarations from dynet/dynet.h //
/////////////////////////////////////
Expand Down Expand Up @@ -770,6 +791,7 @@ class Device {
Device& operator=(const Device&) = delete;
virtual ~Device();
public:
void reset_rng(unsigned seed) {};
int device_id;
DeviceType type;
MemAllocator* mem;
Expand All @@ -785,6 +807,36 @@ class Device {

extern Device* default_device; // where parameters go by default

class DeviceManager final {
public:
DeviceManager();
~DeviceManager();

void clear();

void add(Device* d);

Device* get(size_t i) { return devices[i]; }

size_t num_devices() const { return devices.size(); }

const std::vector<Device*>& get_devices() const { return devices; }

Device* get_global_device(const std::string & name);

// no copying allowed
DeviceManager(const DeviceManager &) = delete;
void operator=(const DeviceManager &) = delete;

private:
std::vector<Device*> devices;
std::unordered_map<std::string, Device*> devices_map;
};

DeviceManager* get_device_manager();

inline void show_pool_mem_info();

////////////////////////////////////////
// declarations from dynet/training.h //
////////////////////////////////////////
Expand Down Expand Up @@ -1233,12 +1285,10 @@ struct DynetParams {
int profiling = 0; /**< Whether to show profiling info or not */
bool shared_parameters = false; /**< TO DOCUMENT */

#ifdef SWIG_USE_CUDA
bool ngpus_requested = false; /**< GPUs requested by number */
bool ids_requested = false; /**< GPUs requested by ids */
int requested_gpus = -1; /**< Number of requested GPUs */
std::vector<int> gpu_mask; /**< List of required GPUs by ids */
#endif
};


Expand Down
2 changes: 1 addition & 1 deletion contrib/swig/project/assembly.sbt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.5")
addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.9")
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ package edu.cmu.dynet
object ComputationGraph {
private[dynet] var cg: internal.ComputationGraph = internal.ComputationGraph.getNew
var version: Long = 0L
private var defaultDevice: internal.Device = internal.dynet_swig.getDefault_device()
private val defaultDevice: internal.Device = internal.dynet_swig.getDefault_device()

/** Gets rid of the singleton Computation Graph and replaces it with a fresh one. Increments
* `version` to make sure we don't use any stale expressions.
Expand Down
15 changes: 15 additions & 0 deletions contrib/swig/src/main/scala/edu/cmu/dynet/Device.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package edu.cmu.dynet

/** Convenience accessors for DyNet computation devices (CPU/GPU).
  *
  * Thin Scala-side facade over the SWIG-generated `internal` bindings;
  * devices themselves are created and owned by the native DyNet runtime.
  */
object Device {

  /** The process-wide default device, as chosen by DyNet at initialization.
    * `lazy` so the native library is only touched on first use.
    */
  lazy val default: internal.Device = internal.dynet_swig.getDefault_device

  /** Looks up a device by name.
    *
    * @param str a device name (e.g. "GPU:0"); the empty string or "default"
    *            resolves to the default device
    * @return the matching native device handle
    */
  def apply(str: String): internal.Device = {
    // Reuse the cached `default` instead of re-querying the native layer.
    if (str.isEmpty || str == "default") default
    else DeviceManager.getGlobalDevice(str)
  }

  /** All devices currently registered with the global device manager,
    * snapshotted at first access (lazy).
    */
  lazy val available: Vector[internal.Device] =
    (0L until DeviceManager.numDevices()).map(DeviceManager.get).toVector
}
17 changes: 17 additions & 0 deletions contrib/swig/src/main/scala/edu/cmu/dynet/DeviceManager.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package edu.cmu.dynet

/** Scala facade over DyNet's global native device manager.
  *
  * All members delegate directly to the SWIG-generated `internal` bindings;
  * no state is kept on the Scala side beyond the manager handle itself.
  */
object DeviceManager {

  // Handle to the singleton native DeviceManager.
  private[dynet] val dm: internal.DeviceManager = internal.dynet_swig.get_device_manager()

  /** Number of devices currently registered with the manager. */
  def numDevices(): Long = dm.num_devices()

  /** Returns the device at index `l` (no bounds checking on the Scala side). */
  def get(l: Long): internal.Device = dm.get(l)

  /** Looks up a registered device by its global name. */
  def getGlobalDevice(name: String): internal.Device = dm.get_global_device(name)

  /** Registers a device with the global manager. */
  def add(d: internal.Device): Unit = dm.add(d)

  /** The process-wide default device (same as `Device.default`). */
  def getDefaultDevice: internal.Device = internal.dynet_swig.getDefault_device

  /** Prints native memory-pool usage information to stdout. */
  def showMemPoolInfo(): Unit = internal.dynet_swig.show_pool_mem_info()
}
6 changes: 3 additions & 3 deletions contrib/swig/src/main/scala/edu/cmu/dynet/Dim.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class Dim private[dynet] (private[dynet] val dim: internal.Dim) {
def truncate(): Dim = new Dim(dim.truncate())
def singleBatch(): Dim = new Dim(dim.single_batch())

def resize(i: Long) = dim.resize(i)
def resize(i: Long): Unit = dim.resize(i)
def nDims(): Long = dim.ndims()
def rows(): Long = dim.rows()
def cols(): Long = dim.cols()
Expand All @@ -31,15 +31,15 @@ class Dim private[dynet] (private[dynet] val dim: internal.Dim) {
/** We override `equals` so that `Dim` objects should be equal whenever all of their dimension
* sizes match.
*/
override def equals(that: Any) = that match {
override def equals(that: Any): Boolean = that match {
case that: Dim => dim == that.dim
case _ => false
}
override def hashCode(): Int = dim.hashCode()

override def toString: String = "Dim(" + (0 until nDims.toInt).map(get(_)).mkString(", ") + ")"

def debugString(): String = s"(Dim: ${size} ${nDims} ${(0 until nDims.toInt).map(get(_))} )"
def debugString(): String = s"(Dim: $size $nDims ${(0 until nDims.toInt).map(get(_))} )"
}

/** Factory for [[edu.cmu.dynet.Dim]] instances. */
Expand Down
Loading