API Reference - Models - RandomNetworkDistillation
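RandomNetworkDistillation is a neural network-based model for producing internal (intrinsic) rewards that encourage exploration. It compares the output of a fixed, randomly initialized target network with the output of a predictor network that is trained to imitate it; inputs that the predictor reproduces poorly receive larger rewards. It requires a neural network as its model.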
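As a rough illustration of the idea, the sketch below computes an intrinsic reward as the mean squared difference between the outputs of a fixed, randomly initialized target network and a predictor network. It is written in plain Python rather than against this library, both networks are simplified to a single linear layer, and all names (random_matrix, forward, intrinsic_reward) are hypothetical.

```python
import random


def random_matrix(rows, cols, rng):
    """Return a rows x cols weight matrix with standard normal entries."""
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def forward(weights, features):
    """Single linear layer: output[i] = sum_j weights[i][j] * features[j]."""
    return [sum(w * x for w, x in zip(row, features)) for row in weights]


def intrinsic_reward(target_weights, predictor_weights, features):
    """Mean squared error between the target and predictor outputs."""
    target_out = forward(target_weights, features)
    predictor_out = forward(predictor_weights, features)
    squared = [(t - p) ** 2 for t, p in zip(target_out, predictor_out)]
    return sum(squared) / len(squared)


rng = random.Random(0)
target = random_matrix(4, 3, rng)     # randomly initialized, never trained
predictor = random_matrix(4, 3, rng)  # trained elsewhere to imitate the target
state = [0.5, -1.0, 2.0]
print(intrinsic_reward(target, predictor, state))
```

States that the predictor has learned to imitate well yield small rewards, so unfamiliar states stand out with larger ones.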
Constructors
new()
Create a new model object. If any of the arguments are nil, the default value for that argument will be used.
RandomNetworkDistillation.new(): RandomNetworkDistillationObject
Returns
- Model: The generated model object.
Functions
generate()
RandomNetworkDistillation:generate(featureTensor: tensor): tensor
Parameters
- featureTensor: The tensor containing all the features.
Returns
- outputTensor: The tensor generated by the model from the given feature tensor.
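A minimal sketch of what generate() might compute, assuming each row of featureTensor is one observation and the output holds one intrinsic reward per row (the exact output shape is not specified above). Tensors are modeled as nested Python lists; the function names and the single-linear-layer networks are hypothetical.

```python
def forward(weights, features):
    """Single linear layer applied to one feature vector."""
    return [sum(w * x for w, x in zip(row, features)) for row in weights]


def generate(target_weights, predictor_weights, feature_tensor):
    """Return an n x 1 tensor: one intrinsic reward per row of feature_tensor."""
    output_tensor = []
    for features in feature_tensor:
        target_out = forward(target_weights, features)
        predictor_out = forward(predictor_weights, features)
        error = sum((t - p) ** 2 for t, p in zip(target_out, predictor_out))
        output_tensor.append([error / len(target_out)])
    return output_tensor


target = [[1.0, 0.0], [0.0, 1.0]]
predictor = [[1.0, 0.0], [0.0, 0.5]]
feature_tensor = [[1.0, 0.0], [0.0, 2.0]]
print(generate(target, predictor, feature_tensor))  # [[0.0], [0.5]]
```

The first row gets zero reward because the predictor already matches the target there; the second row exposes the mismatch in the predictor's second output.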
setModel()
RandomNetworkDistillation:setModel(Model: ModelObject)
Parameters
- Model: The model to be used by the RandomNetworkDistillation object.
getModel()
RandomNetworkDistillation:getModel(): ModelObject
Returns
- Model: The model that is used by the RandomNetworkDistillation object.
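Because the object exposes separate target and predictor weight tensor arrays but only one model, one plausible arrangement is a single network architecture evaluated with two different weight sets. The sketch below illustrates that arrangement under this assumption; LinearModel, RNDSketch, and their methods are hypothetical stand-ins, not this library's implementation, and the weight attributes are assigned directly for brevity.

```python
class LinearModel:
    """Hypothetical stand-in for a neural network: one linear layer."""

    def __init__(self):
        self.weights = None

    def set_weights(self, weights):
        self.weights = weights

    def predict(self, features):
        return [sum(w * x for w, x in zip(row, features)) for row in self.weights]


class RNDSketch:
    """Hypothetical wrapper holding one model and two weight sets."""

    def __init__(self):
        self.model = None
        self.target_weights = None
        self.predictor_weights = None

    def set_model(self, model):
        self.model = model

    def get_model(self):
        return self.model

    def _output(self, weights, features):
        # Load one weight set into the shared model, then run it.
        self.model.set_weights(weights)
        return self.model.predict(features)

    def generate(self, features):
        target_out = self._output(self.target_weights, features)
        predictor_out = self._output(self.predictor_weights, features)
        errors = [(t - p) ** 2 for t, p in zip(target_out, predictor_out)]
        return sum(errors) / len(errors)


rnd = RNDSketch()
model = LinearModel()
rnd.set_model(model)
rnd.target_weights = [[2.0, 0.0]]
rnd.predictor_weights = [[1.0, 0.0]]
print(rnd.generate([3.0, 1.0]))  # target 6.0, predictor 3.0 -> 9.0
```

Sharing one architecture keeps the target and predictor outputs directly comparable, since only the weights differ between the two passes.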
getTargetWeightTensorArray()
Gets the target weight tensor array from the network.
RandomNetworkDistillation:getTargetWeightTensorArray(doNotDeepCopy: boolean): WeightTensorArray
Parameters
- doNotDeepCopy: If true, the weight tensor array is returned as-is instead of as a deep copy.
Returns
- TargetWeightTensorArray: The target network's weight tensor array, which serves as the fixed reference for predictor network training.
getPredictorWeightTensorArray()
Gets the predictor weight tensor array from the network.
RandomNetworkDistillation:getPredictorWeightTensorArray(doNotDeepCopy: boolean): WeightTensorArray
Parameters
- doNotDeepCopy: If true, the weight tensor array is returned as-is instead of as a deep copy.
Returns
- PredictorWeightTensorArray: The predictor network's weight tensor array, which is trained so that the predictor's outputs match those of the target network.
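The doNotDeepCopy flag on the getters (and on the setters that follow) reflects the usual copy-versus-reference trade-off. The Python sketch below assumes that true hands out or keeps the caller's object itself, while false works with an independent deep copy; WeightStore, get_weights, and set_weights are hypothetical names.

```python
import copy


class WeightStore:
    """Hypothetical holder for one weight tensor array."""

    def __init__(self):
        self._weights = None

    def set_weights(self, weights, do_not_deep_copy=False):
        # Keep the caller's object, or store an independent deep copy.
        self._weights = weights if do_not_deep_copy else copy.deepcopy(weights)

    def get_weights(self, do_not_deep_copy=False):
        # Hand out the stored object, or an independent deep copy.
        return self._weights if do_not_deep_copy else copy.deepcopy(self._weights)


store = WeightStore()
original = [[[1.0, 2.0], [3.0, 4.0]]]  # one layer holding a 2x2 matrix
store.set_weights(original)            # deep copied on the way in
original[0][0][0] = 50.0
print(store.get_weights()[0][0][0])    # 1.0: later edits to original are not seen

copied = store.get_weights()
copied[0][0][0] = 99.0
print(store.get_weights()[0][0][0])    # 1.0: the copy is independent

shared = store.get_weights(do_not_deep_copy=True)
shared[0][0][0] = 99.0
print(store.get_weights()[0][0][0])    # 99.0: the shared reference was modified
```

Skipping the copy saves time and memory for large networks, at the cost of aliasing: any edit to the returned array also changes the stored weights.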
setTargetWeightTensorArray()
Sets the target weight tensor array of the network.
RandomNetworkDistillation:setTargetWeightTensorArray(TargetWeightTensorArray: WeightTensorArray, doNotDeepCopy: boolean)
Parameters
- TargetWeightTensorArray: The target network's weight tensor array, which serves as the fixed reference for predictor network training.
- doNotDeepCopy: If true, the given weight tensor array is stored as-is instead of as a deep copy.
setPredictorWeightTensorArray()
Sets the predictor weight tensor array of the network.
RandomNetworkDistillation:setPredictorWeightTensorArray(PredictorWeightTensorArray: WeightTensorArray, doNotDeepCopy: boolean)
Parameters
- PredictorWeightTensorArray: The predictor network's weight tensor array, which is trained so that the predictor's outputs match those of the target network.
- doNotDeepCopy: If true, the given weight tensor array is stored as-is instead of as a deep copy.
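To illustrate why only the predictor's weights are trained while the target's stay fixed, the sketch below runs plain gradient descent on the squared error between two single-layer networks for one repeatedly visited state. This is a generic RND-style update written in Python, not this library's training routine; the learning rate, step count, and helper names are assumptions.

```python
import random


def forward(weights, features):
    return [sum(w * x for w, x in zip(row, features)) for row in weights]


def reward(target, predictor, features):
    pairs = zip(forward(target, features), forward(predictor, features))
    errors = [(t - p) ** 2 for t, p in pairs]
    return sum(errors) / len(errors)


def train_predictor_step(target, predictor, features, learning_rate=0.1):
    """One gradient step on 0.5 * sum((p - t)^2); only the predictor changes."""
    target_out = forward(target, features)
    predictor_out = forward(predictor, features)
    for i, (p, t) in enumerate(zip(predictor_out, target_out)):
        error = p - t
        for j, x in enumerate(features):
            predictor[i][j] -= learning_rate * error * x


rng = random.Random(1)
target = [[rng.gauss(0.0, 1.0) for _ in range(3)] for _ in range(4)]
predictor = [[rng.gauss(0.0, 1.0) for _ in range(3)] for _ in range(4)]

familiar = [1.0, 0.0, 0.0]  # the state the agent keeps visiting
novel = [0.0, 1.0, 0.0]     # a state the predictor never trains on

familiar_before = reward(target, predictor, familiar)
novel_before = reward(target, predictor, novel)
for _ in range(100):
    train_predictor_step(target, predictor, familiar)

print(familiar_before, reward(target, predictor, familiar))  # drops toward 0
print(novel_before, reward(target, predictor, novel))        # unchanged here
```

The novel state's reward stays exactly the same only because, in this linear toy, it shares no nonzero features with the training state; with real networks the two generally interact.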