Skip to content

Commit 5e639d3

Browse files
committed
Use precompiled type if availableAlso remove all references to jl_current_module
1 parent 9b1e8dd commit 5e639d3

File tree

5 files changed

+128
-30
lines changed

5 files changed

+128
-30
lines changed

examples/types.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ struct World
4444
~World() { std::cout << "Destroying World with message " << msg << std::endl; }
4545
};
4646

47+
struct Array {};
48+
4749
struct NonCopyable
4850
{
4951
NonCopyable() {}
@@ -101,7 +103,7 @@ void call_testype_function()
101103
jlcxx::JuliaFunction("julia_test_func")(result);
102104
}
103105

104-
enum CppEnum
106+
enum MyEnum
105107
{
106108
EnumValA,
107109
EnumValB
@@ -123,7 +125,7 @@ struct NullableStruct {};
123125

124126
namespace jlcxx
125127
{
126-
template<> struct IsBits<cpp_types::CppEnum> : std::true_type {};
128+
template<> struct IsBits<cpp_types::MyEnum> : std::true_type {};
127129
template<typename T> struct IsSmartPointerType<cpp_types::MySmartPointer<T>> : std::true_type { };
128130
template<typename T> struct ConstructorPointerType<cpp_types::MySmartPointer<T>> { typedef std::shared_ptr<T> type; };
129131
}
@@ -143,6 +145,8 @@ JLCXX_MODULE define_julia_module(jlcxx::Module& types)
143145
.method("greet", &World::greet)
144146
.method("greet_lambda", [] (const World& w) { return w.greet(); } );
145147

148+
types.add_type<Array>("Array");
149+
146150
types.method("world_factory", []()
147151
{
148152
return new World("factory hello");
@@ -230,10 +234,10 @@ JLCXX_MODULE define_julia_module(jlcxx::Module& types)
230234
.method("greet", &ConstPtrConstruct::greet);
231235

232236
// Enum
233-
types.add_bits<CppEnum>("CppEnum", jlcxx::julia_type("CppEnum"));
237+
types.add_bits<MyEnum>("MyEnum", jlcxx::julia_type("CppEnum"));
234238
types.set_const("EnumValA", EnumValA);
235239
types.set_const("EnumValB", EnumValB);
236-
types.method("enum_to_int", [] (const CppEnum e) { return static_cast<int>(e); });
240+
types.method("enum_to_int", [] (const MyEnum e) { return static_cast<int>(e); });
237241
types.method("get_enum_b", [] () { return EnumValB; });
238242

239243
types.add_type<Foo>("Foo")

include/jlcxx/module.hpp

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
namespace jlcxx
1818
{
1919

20-
/// Compatibility between 0.6 and 0.7
20+
/// Wrappers for creating new datatype
2121
JLCXX_API jl_datatype_t* new_datatype(jl_sym_t *name,
2222
jl_module_t* module,
2323
jl_datatype_t *super,
@@ -26,6 +26,11 @@ JLCXX_API jl_datatype_t* new_datatype(jl_sym_t *name,
2626
int abstract, int mutabl,
2727
int ninitialized);
2828

29+
JLCXX_API jl_datatype_t* new_bitstype(jl_sym_t *name,
30+
jl_module_t* module,
31+
jl_datatype_t *super,
32+
jl_svec_t *parameters, const size_t nbits);
33+
2934
/// Some helper functions
3035
namespace detail
3136
{
@@ -491,13 +496,7 @@ class JLCXX_API Module
491496
return module_name(m_jl_mod);
492497
}
493498

494-
void bind_constants(jl_module_t* mod)
495-
{
496-
for(auto& dt_pair : m_jl_constants)
497-
{
498-
jl_set_const(mod, jl_symbol(dt_pair.first.c_str()), dt_pair.second);
499-
}
500-
}
499+
void bind_constants(jl_module_t* mod);
501500

502501
jl_datatype_t* get_julia_type(const char* name)
503502
{
@@ -1049,13 +1048,7 @@ void Module::add_bits(const std::string& name, JLSuperT* super)
10491048
static_assert(std::is_scalar<T>::value, "Bits types must be a scalar type");
10501049
jl_svec_t* params = is_parametric ? parameter_list<T>()() : jl_emptysvec;
10511050
JL_GC_PUSH1(&params);
1052-
#if JULIA_VERSION_MAJOR == 0 && JULIA_VERSION_MINOR < 6
1053-
jl_datatype_t* dt = jl_new_bitstype((jl_value_t*)jl_symbol(name.c_str()), (jl_datatype_t*)super, params, 8*sizeof(T));
1054-
#elif JULIA_VERSION_MAJOR == 0 && JULIA_VERSION_MINOR < 7
1055-
jl_datatype_t* dt = jl_new_primitivetype((jl_value_t*)jl_symbol(name.c_str()), (jl_datatype_t*)super, params, 8*sizeof(T));
1056-
#else
1057-
jl_datatype_t* dt = jl_new_primitivetype((jl_value_t*)jl_symbol(name.c_str()), m_jl_mod, (jl_datatype_t*)super, params, 8*sizeof(T));
1058-
#endif
1051+
jl_datatype_t* dt = new_bitstype(jl_symbol(name.c_str()), m_jl_mod, (jl_datatype_t*)super, params, 8*sizeof(T));
10591052
protect_from_gc(dt);
10601053
JL_GC_POP();
10611054
detail::dispatch_set_julia_type<T, is_parametric>()(dt);
@@ -1085,6 +1078,7 @@ class JLCXX_API ModuleRegistry
10851078
return m_modules.find(jmod) != m_modules.end();
10861079
}
10871080

1081+
bool has_current_module() { return m_current_module != nullptr; }
10881082
Module& current_module();
10891083
void reset_current_module() { m_current_module = nullptr; }
10901084

include/jlcxx/smart_pointers.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@ struct BaseMapping
8787
{
8888
};
8989

90-
template<template<typename...> class PtrT, typename PointeeT, typename OtherPtrT>
91-
struct BaseMapping<PtrT<PointeeT>, OtherPtrT>
90+
template<template<typename...> class PtrT, typename PointeeT, typename OtherPtrT, typename... ExtraArgs>
91+
struct BaseMapping<PtrT<PointeeT, ExtraArgs...>, OtherPtrT>
9292
{
9393
template<bool B, typename DummyT=void>
9494
struct ConditionalConstructFromOther

src/functions.cpp

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "jlcxx/functions.hpp"
2+
#include "jlcxx/module.hpp"
23

34
// This header provides helper functions to call Julia functions from C++
45

@@ -7,10 +8,27 @@ namespace jlcxx
78

89
JuliaFunction::JuliaFunction(const std::string& name, const std::string& module_name)
910
{
10-
jl_module_t* mod = module_name.empty() ? jl_current_module : (jl_module_t*)jl_get_global(jl_current_module, jl_symbol(module_name.c_str()));
11+
jl_module_t* mod = nullptr;
12+
jl_module_t* current_mod = nullptr;
13+
if(registry().has_current_module())
14+
{
15+
current_mod = registry().current_module().julia_module();
16+
}
17+
if(!module_name.empty())
18+
{
19+
mod = (jl_module_t*)jl_get_global(jl_main_module, jl_symbol(module_name.c_str()));
20+
if(mod == nullptr && current_mod != nullptr)
21+
{
22+
mod = (jl_module_t *)jl_get_global(current_mod, jl_symbol(module_name.c_str()));
23+
}
24+
if(mod == nullptr)
25+
{
26+
throw std::runtime_error("Could not find module " + module_name + " when looking up function " + name);
27+
}
28+
}
1129
if(mod == nullptr)
1230
{
13-
throw std::runtime_error("Could not find module " + module_name + " when looking up function " + module_name);
31+
mod = current_mod == nullptr ? jl_main_module : current_mod;
1432
}
1533

1634
m_function = jl_get_function(mod, name.c_str());

src/jlcxx.cpp

Lines changed: 89 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,14 @@ int_t Module::store_pointer(void *ptr)
5555
return m_pointer_array.size();
5656
}
5757

58+
void Module::bind_constants(jl_module_t* mod)
59+
{
60+
for(auto& dt_pair : m_jl_constants)
61+
{
62+
jl_set_const(mod, jl_symbol(dt_pair.first.c_str()), dt_pair.second);
63+
}
64+
}
65+
5866
Module &ModuleRegistry::create_module(jl_module_t* jmod)
5967
{
6068
if(jmod == nullptr)
@@ -91,8 +99,43 @@ JLCXX_API ModuleRegistry& registry()
9199

92100
JLCXX_API jl_value_t* julia_type(const std::string& name, const std::string& module_name)
93101
{
94-
const auto mods = {module_name.empty() ? nullptr : (jl_module_t*)jl_get_global(jl_current_module, jl_symbol(module_name.c_str())), jl_base_module, g_cxxwrap_module, jl_current_module, jl_current_module->parent};
95-
std::string found_type;
102+
std::vector<jl_module_t*> mods;
103+
mods.reserve(6);
104+
jl_module_t* current_mod = registry().has_current_module() ? registry().current_module().julia_module() : nullptr;
105+
if(!module_name.empty())
106+
{
107+
jl_sym_t* modsym = jl_symbol(module_name.c_str());
108+
jl_module_t* found_mod = nullptr;
109+
if(current_mod != nullptr)
110+
{
111+
found_mod = (jl_module_t*)jl_get_global(current_mod, modsym);
112+
}
113+
if(found_mod == nullptr)
114+
{
115+
found_mod = (jl_module_t*)jl_get_global(jl_main_module, jl_symbol(module_name.c_str()));
116+
}
117+
if(found_mod != nullptr)
118+
{
119+
mods.push_back(found_mod);
120+
}
121+
else
122+
{
123+
throw std::runtime_error("Failed to find module " + module_name);
124+
}
125+
}
126+
else
127+
{
128+
if (current_mod != nullptr)
129+
{
130+
mods.push_back(current_mod);
131+
}
132+
mods.push_back(jl_main_module);
133+
mods.push_back(jl_base_module);
134+
mods.push_back(g_cxxwrap_module);
135+
mods.push_back(jl_top_module);
136+
}
137+
138+
std::string found_type = "null";
96139
for(jl_module_t* mod : mods)
97140
{
98141
if(mod == nullptr)
@@ -170,6 +213,24 @@ std::wstring ConvertToCpp<std::wstring, false, false, false>::operator()(jl_valu
170213
return std::wstring(arr.data(), arr.size());
171214
}
172215

216+
static constexpr const char* dt_prefix = "__cxxwrap_dt_";
217+
218+
jl_datatype_t* existing_datatype(jl_module_t* mod, jl_sym_t* name)
219+
{
220+
const std::string prefixed_name = dt_prefix + symbol_name(name);
221+
jl_value_t* found_dt = jl_get_global(mod, jl_symbol(prefixed_name.c_str()));
222+
if(found_dt == nullptr || !jl_is_datatype(found_dt))
223+
{
224+
return nullptr;
225+
}
226+
return (jl_datatype_t*)found_dt;
227+
}
228+
229+
void set_internal_constant(jl_module_t* mod, jl_datatype_t* dt, const std::string& prefixed_name)
230+
{
231+
jl_set_const(mod, jl_symbol(prefixed_name.c_str()), (jl_value_t*)dt);
232+
}
233+
173234
JLCXX_API jl_datatype_t* new_datatype(jl_sym_t *name,
174235
jl_module_t* module,
175236
jl_datatype_t *super,
@@ -182,11 +243,32 @@ JLCXX_API jl_datatype_t* new_datatype(jl_sym_t *name,
182243
{
183244
throw std::runtime_error("null module when creating type");
184245
}
185-
#if JULIA_VERSION_MAJOR == 0 && JULIA_VERSION_MINOR < 7
186-
return jl_new_datatype(name, super, parameters, fnames, ftypes, abstract, mutabl, ninitialized);
187-
#else
188-
return jl_new_datatype(name, module, super, parameters, fnames, ftypes, abstract, mutabl, ninitialized);
189-
#endif
246+
jl_datatype_t* dt = existing_datatype(module, name);
247+
if(dt != nullptr)
248+
{
249+
return dt;
250+
}
251+
252+
dt = jl_new_datatype(name, module, super, parameters, fnames, ftypes, abstract, mutabl, ninitialized);
253+
set_internal_constant(module, dt, dt_prefix + symbol_name(name));
254+
return dt;
255+
}
256+
257+
JLCXX_API jl_datatype_t* new_bitstype(jl_sym_t *name,
258+
jl_module_t* module,
259+
jl_datatype_t *super,
260+
jl_svec_t *parameters, const size_t nbits)
261+
{
262+
assert(module != nullptr);
263+
jl_datatype_t* dt = existing_datatype(module, name);
264+
if(dt != nullptr)
265+
{
266+
return dt;
267+
}
268+
269+
dt = jl_new_primitivetype((jl_value_t*)name, module, super, parameters, nbits);
270+
set_internal_constant(module, dt, dt_prefix + symbol_name(name));
271+
return dt;
190272
}
191273

192274
}

0 commit comments

Comments
 (0)