@@ -322,7 +322,7 @@ class BitmapTexture final : public Texture<Float, Spectrum> {
322322 m_wrap_mode,
323323 m_raw,
324324 m_accel,
325- tensor);
325+ std::move ( tensor) );
326326 }
327327
328328private:
@@ -367,35 +367,37 @@ class BitmapTextureImpl : public Texture<Float, Spectrum> {
367367 using StoredTensorXf = dr::replace_scalar_t <TensorXf, StoredScalar>;
368368 using StoredTexture2f = dr::Texture<StoredType, 2 >;
369369
370+ template <typename Tensor>
370371 BitmapTextureImpl (const Properties &props,
371- const std::string& name,
372- const ScalarTransform3f& transform,
373- dr::FilterMode filter_mode,
374- dr::WrapMode wrap_mode,
375- bool raw,
376- bool accel,
377- StoredTensorXf & tensor) :
372+ const std::string& name,
373+ const ScalarTransform3f& transform,
374+ dr::FilterMode filter_mode,
375+ dr::WrapMode wrap_mode,
376+ bool raw,
377+ bool accel,
378+ Tensor& & tensor) :
378379 Texture (props),
379380 m_name (name),
380381 m_transform (transform),
381382 m_accel (accel),
382- m_raw (raw),
383- m_texture (tensor, accel, accel, filter_mode, wrap_mode) {
383+ m_raw (raw) {
384384
385385 /* Compute mean without migrating texture data
386386 i.e. Avoid call to m_texture.tensor() that triggers migration.
387387 For CUDA-variants, ideally want to solely keep data as CUDA texture
388388 */
389389 rebuild_internals (tensor, true , false );
390+
391+ m_texture = StoredTexture2f (std::forward<Tensor>(tensor), accel, accel,
392+ filter_mode, wrap_mode);
390393 }
391394
392395 void traverse (TraversalCallback *callback) override {
393396 callback->put_parameter (" data" , m_texture.tensor (), +ParamFlags::Differentiable);
394397 callback->put_parameter (" to_uv" , m_transform, +ParamFlags::NonDifferentiable);
395398 }
396399
397- void
398- parameters_changed (const std::vector<std::string> &keys = {}) override {
400+ void parameters_changed (const std::vector<std::string> &keys = {}) override {
399401 if (keys.empty () || string::contains (keys, " data" )) {
400402 const size_t channels = m_texture.shape ()[2 ];
401403 if (channels != 1 && channels != 3 )
@@ -801,13 +803,14 @@ class BitmapTextureImpl : public Texture<Float, Spectrum> {
801803 if (m_transform != ScalarTransform3f ())
802804 dr::make_opaque (m_transform);
803805
804- size_t pixel_count = ( size_t ) dr::prod ( resolution () );
805- const size_t channels = m_texture. shape ()[ 2 ];
806- bool range_issue = false ;
806+ const dr::vector< size_t > &shape = tensor. shape ( );
807+ size_t pixel_count = shape[ 0 ] * shape[ 1 ],
808+ channels = shape[ 3 ] ;
807809
810+ bool range_issue = false ;
808811 using FloatStorage = DynamicBuffer<Float>;
809812 using StoredTypeArray= DynamicBuffer<StoredType>;
810- FloatStorage values = dr::empty<FloatStorage>(pixel_count) ;
813+ FloatStorage values;
811814
812815 if (channels == 3 ) {
813816 if constexpr (dr::is_jit_v<Float>) {
@@ -824,7 +827,11 @@ class BitmapTextureImpl : public Texture<Float, Spectrum> {
824827 values = luminance (colors_fl);
825828 } else {
826829 StoredScalar* ptr = (StoredScalar*) tensor.data ();
827- ScalarFloat *out = values.data (), mean = 0 ;
830+ ScalarFloat *out = nullptr , mean = 0 ;
831+ if (init_distr) {
832+ values = dr::empty<FloatStorage>(pixel_count);
833+ out = values.data ();
834+ }
828835
829836 for (size_t i = 0 ; i < pixel_count; ++i) {
830837 Color3f col (ptr[0 ], ptr[1 ], ptr[2 ]);
@@ -836,7 +843,8 @@ class BitmapTextureImpl : public Texture<Float, Spectrum> {
836843 else
837844 lum = luminance (col);
838845
839- *out++ = lum;
846+ if (init_distr)
847+ *out++ = lum;
840848 mean += lum;
841849 range_issue |= lum < 0 || lum > 1 ;
842850 }
@@ -848,10 +856,15 @@ class BitmapTextureImpl : public Texture<Float, Spectrum> {
848856 values = tensor.array ();
849857 } else {
850858 StoredScalar* ptr = (StoredScalar*) tensor.data ();
851- ScalarFloat *out = values.data (), mean = 0 ;
859+ ScalarFloat *out = nullptr , mean = 0 ;
860+ if (init_distr) {
861+ values = dr::empty<FloatStorage>(pixel_count);
862+ out = values.data ();
863+ }
852864 for (size_t i = 0 ; i < pixel_count; ++i) {
853865 ScalarFloat value = ptr[i];
854- *out++ = value;
866+ if (init_distr)
867+ *out++ = value;
855868 m_mean += value;
856869 range_issue |= value < 0 || value > 1 ;
857870 }
0 commit comments