I'm trying to write an S4 class for functions that return a numeric vector of the same length as their input. I think I'm close; the problem I'm having now is that I can only create new TransFunc objects from functions that live in my global environment (GlobalEnv).
library(S4Vectors)
setClass("TransFunc", contains = c("function"), prototype = function(x) x)
TransFunc <- function(x) {
  if (missing(x)) return(new("TransFunc"))
  new2("TransFunc", x)
}
.TransFunc.validity <- function(object) {
  msg <- NULL
  if (length(formals(object)) > 1) {
    msg <- c(msg, "TransFunc must only have one argument.")
  }
  res1 <- object(1:5)
  res2 <- object(1:6)
  if (length(res1) != 5 || length(res2) != 6) {
    msg <- c(msg, "TransFunc output length must equal input length.")
  }
  if (!class(res1) %in% c("numeric", "integer")) {
    msg <- c(msg, "TransFunc output must be numeric for numeric inputs.")
  }
  if (is.null(msg)) return(TRUE)
  msg
}
setValidity2(Class = "TransFunc", method = .TransFunc.validity)
mysqrt <- TransFunc(function(x) sqrt(x))
mysqrt <- TransFunc(sqrt) ## Errors... why??
## Error in initialize(value, ...) :
## 'initialize' method returned an object of class “function” instead
## of the required class “TransFunc”
The benefit of having the class inherit directly from function is that its objects can be used as regular functions:
mysqrt(1:5)
## [1] 1.000000 1.414214 1.732051 2.000000 2.236068
body(mysqrt) <- expression(sqrt(x)^2)
mysqrt(1:10)
## [1] 1 2 3 4 5 6 7 8 9 10
Why does it error when I pass functions defined outside the global environment?
Answer:
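It does not work for sqrt because sqrt is a primitive function (typeof(sqrt) is "builtin"), not an ordinary R closure. The default initialize method cannot turn a primitive into a TransFunc object and hands back the plain function instead, which is exactly what the error message reports. Where the function was defined has nothing to do with it.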
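A quick base-R check of the difference (a small illustrative snippet, not part of the original answer): is.primitive() and typeof() separate primitives from closures. Note also that formals() returns NULL for a primitive, so the one-argument check in the validity function would not catch a primitive on its own; args() builds a closure stub that does carry the signature.

```r
# Primitives are implemented in C and are not ordinary R closures
is.primitive(sqrt)          # TRUE
typeof(sqrt)                # "builtin"
is.primitive(mean.default)  # FALSE: mean.default is an ordinary R closure
typeof(mean.default)        # "closure"

# formals() gives NULL for a primitive; args() returns a stub closure with the signature
is.null(formals(sqrt))      # TRUE
names(formals(args(sqrt)))  # "x"
```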
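Many of the obvious one-argument candidates in base R (sqrt, exp, abs, ...) are primitives. To show how your code works with a non-primitive function from the preloaded packages, I therefore cut your validity function down, dropping the one-argument check and the output-length check, both of which mean.default would fail: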
# Using your class definition and constructor
.TransFunc.validity <- function(object) {
  msg <- NULL
  res1 <- object(1:5)
  if (!class(res1) %in% c("numeric", "integer")) {
    msg <- c(msg, "TransFunc output must be numeric for numeric inputs.")
  }
  if (is.null(msg)) return(TRUE)
  msg
}
setValidity2(Class = "TransFunc", method = .TransFunc.validity)
Here are the results for the default method of mean (mean.default), which is an ordinary closure:
mymean <- TransFunc(mean.default)
mymean(1:5)
[1] 3
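For reference, a short base-R check (illustrative, not from the original answer) of why mean.default needed the cut-down validity: it passes the numeric-output check but would fail the two checks that were dropped.

```r
# The two checks dropped from the cut-down validity function:
length(formals(mean.default))  # 4 (x, trim, na.rm, ...), so the one-argument check fails
length(mean.default(1:5))      # 1, so the output-length check fails as well
```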
Here is a workaround: modify the initialize method of your class so that it catches primitives and wraps them in closures:
# I modified the class definition to drop the prototype; the default function now comes from initialize()
setClass("TransFunc", contains = c("function"))
TransFunc <- function(x) {
  if (missing(x)) return(new("TransFunc"))
  new2("TransFunc", x)
}
# Keeping your validity function, I changed initialize to:
setMethod("initialize", "TransFunc",
  function(.Object, .Data = function(x) x, ...) {
    if (typeof(.Data) %in% c("builtin", "special"))
      .Object <- callNextMethod(.Object, function(x) return(.Data(x)), ...)
    else
      .Object <- callNextMethod(.Object, .Data, ...)
    .Object
  })
With this, I got the following results:
mysqrt <- TransFunc(sqrt)
mysqrt(1:5)
[1] 1.000000 1.414214 1.732051 2.000000 2.236068
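If S4Vectors is not available, the same workaround can be sketched with the methods package alone. This is a minimal sketch under a few assumptions: new() and setValidity() stand in for S4Vectors' new2() and setValidity2(), is.primitive() replaces the typeof() test, is.numeric() replaces the class() comparison, and the local name prim is hypothetical. The full validity function from the question is kept.

```r
library(methods)

# Sketch without S4Vectors: new()/setValidity() replace new2()/setValidity2()
setClass("TransFunc", contains = "function")

setMethod("initialize", "TransFunc",
  function(.Object, .Data = function(x) x, ...) {
    if (is.primitive(.Data)) {
      prim <- .Data  # hypothetical local name for the wrapped primitive
      # A primitive cannot become the data part, so forward to it from a closure
      .Data <- function(x) prim(x)
    }
    callNextMethod(.Object, .Data, ...)
  })

setValidity("TransFunc", function(object) {
  msg <- NULL
  if (length(formals(object)) > 1) {
    msg <- c(msg, "TransFunc must only have one argument.")
  }
  res1 <- object(1:5)
  res2 <- object(1:6)
  if (length(res1) != 5 || length(res2) != 6) {
    msg <- c(msg, "TransFunc output length must equal input length.")
  }
  if (!is.numeric(res1)) {
    msg <- c(msg, "TransFunc output must be numeric for numeric inputs.")
  }
  if (is.null(msg)) TRUE else msg
})

TransFunc <- function(x) {
  if (missing(x)) new("TransFunc") else new("TransFunc", x)
}

mysqrt <- TransFunc(sqrt)  # primitive, wrapped by initialize()
mysqrt(1:5)

# A closure whose output length differs from its input is rejected by validity
msg <- tryCatch(TransFunc(function(x) sum(x)), error = conditionMessage)
```

Because initialize() stores a one-argument closure that forwards to the primitive, the resulting object has ordinary formals (x), so it also passes the one-argument check that formals() alone could not apply to the raw primitive.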
EDIT:
In the comments, @ekoam proposes a more general version of initialize for your class:
setMethod("initialize", "TransFunc", function(.Object, ...) {
  maybe_transfunc <- callNextMethod()
  if (is.primitive(maybe_transfunc))
    .Object@.Data <- maybe_transfunc
  else
    .Object <- maybe_transfunc
  .Object
})
