diff --git a/ethosu/regor/common/buffer_view.hpp b/ethosu/regor/common/buffer_view.hpp index 0dc848e66a24db460951b44bdeb2ada5dda6f8ae..92c0dfc89038a7179845e21e20ed608be45434fb 100644 --- a/ethosu/regor/common/buffer_view.hpp +++ b/ethosu/regor/common/buffer_view.hpp @@ -118,7 +118,8 @@ private: Placement _placement = Placement::Remote; LocalStorage _localStorage; DeleteFunc _deleter = nullptr; - Hash128 _dataHash; + mutable Hash128 _dataHash; + mutable bool _invalidHash = true; public: Buffer(const Buffer &) = delete; @@ -327,9 +328,18 @@ public: } } - const Hash128 &Hash() const { return _dataHash; } + const Hash128 &Hash() const + { + if ( _invalidHash ) + { + Rehash(); + } + return _dataHash; + } + + void InvalidateHash() { _invalidHash = true; } - void Rehash() + void Rehash() const { if ( Size() > 0 ) { @@ -338,10 +348,8 @@ public: sizeStr += std::to_string(Size()); sizeStr += '>'; MD5 hash; - // Make sure the const overload of Data() is called - const uint8_t *data = std::as_const(*this).Data(); hash.Combine(reinterpret_cast(sizeStr.data()), int(sizeStr.size())); - hash.Combine(data, Size()); + hash.Combine(Data(), Size()); hash.Get(_dataHash); } else @@ -352,6 +360,7 @@ public: _dataHash.v32[0] = _dataHash.v32[1] = static_cast(ptr); _dataHash.v32[2] = _dataHash.v32[3] = static_cast(ptr >> 32); } + _invalidHash = false; } private: @@ -686,6 +695,7 @@ public: { assert(HasBuffer() && _strideBytes); auto *start = _buffer->Data() + _baseOffset; + _buffer->InvalidateHash(); return BufferWriter(_strideBytes, start, _elements); } diff --git a/ethosu/regor/compiler/scheduler_packing.cpp b/ethosu/regor/compiler/scheduler_packing.cpp index 60377ce9459f4161d36fe0ead377e1274691ef4c..2030d5a3cb89fc56e62645bed61f332a59776431 100644 --- a/ethosu/regor/compiler/scheduler_packing.cpp +++ b/ethosu/regor/compiler/scheduler_packing.cpp @@ -441,7 +441,6 @@ int SchedulerPacking::CanPack(const SchedulerOperation *schedOp, const Scheduler } // Do not pack persistent tensors with non persistent tensors - // if ( ifmTensor->isPersistent != prevOFM->isPersistent ) if ( prevOFM->isPersistent != nextOp->OFM()->tensor->isPersistent ) { return 0;